From 1d0847a90e6bee759ab0af1cc854e78b09892d93 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sat, 7 Sep 2024 01:19:16 +0200 Subject: [PATCH] Revert the max prefix hit. --- backends/v3/src/backend.rs | 2 + backends/v3/src/radix.rs | 71 ++++++------------- .../models/flash_causal_lm.py | 28 +++++--- 3 files changed, 43 insertions(+), 58 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 05a26370..935f7980 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -168,6 +168,8 @@ pub(crate) async fn batching_task( None } else { // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 7fe732f1..1f3bef15 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -70,12 +70,9 @@ impl Allocator for RadixAllocator { ) -> Option { let mut blocks = vec![]; let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { - let node_id = self.cache_blocks.find( - // XXX This is super important we cannot match an entire prefix - // otherwise input_ids is empty and the shard code cannot handle that. - &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)], - &mut blocks, - ); + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); node_id } else { self.cache_blocks.root_id() @@ -92,6 +89,8 @@ impl Allocator for RadixAllocator { let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { @@ -149,9 +148,7 @@ impl Allocator for RadixAllocator { .expect("Failed to decrement refcount"); if let Some(prefill_tokens) = allocation.prefill_tokens { - // XXX We matched everything except the last token - let prefill_tokens = - &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)]; + let prefill_tokens = prefill_tokens.as_slice(); // If there are prefill tokens that did not come from the cache, // add them to the cache. @@ -613,17 +610,13 @@ mod tests { #[test] fn allocator_block_size() { let mut cache = RadixAllocator::new(2, 12, None); - let allocation = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); assert_eq!(allocation.prefix_len, 4); @@ -632,17 +625,13 @@ mod tests { #[test] fn allocator_block_size_non_aligned() { let mut cache = RadixAllocator::new(2, 12, None); - let allocation = cache - .allocate(7, Some(Arc::new(vec![0, 1, 2, 99]))) - .unwrap(); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache - .allocate(7, Some(Arc::new(vec![0, 1, 2, 99]))) - .unwrap(); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); assert_eq!(allocation.prefix_len, 2); @@ -651,17 +640,13 @@ mod tests { #[test] fn allocator_reuses_prefixes() { let mut cache = RadixAllocator::new(1, 12, None); - let allocation = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.blocks, allocation.slots); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.prefix_len, 4); } @@ -669,13 +654,11 @@ mod tests { #[test] fn allocator_collects_older_prefixes_first() { let mut cache = RadixAllocator::new(1, 7, None); - let allocation1 = cache - .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation1.prefix_len, 0); - let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5, 99]))).unwrap(); + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); assert_eq!(allocation2.blocks, vec![1, 2]); assert_eq!(allocation2.prefix_len, 0); @@ -683,9 +666,7 @@ mod tests { cache.free(allocation2.blocks.clone(), allocation2.allocation_id); // We should get the blocks of the first allocation, since they are more recent. - let allocation3 = cache - .allocate(4, Some(Arc::new(vec![6, 7, 8, 9, 99]))) - .unwrap(); + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation3.prefix_len, 0); } @@ -693,19 +674,13 @@ mod tests { #[test] fn allocator_frees_fully_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 10, None); - let allocation1 = cache - .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); - let allocation2 = cache - .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); cache.free(allocation2.blocks.clone(), allocation2.allocation_id); cache.free(allocation1.blocks.clone(), allocation1.allocation_id); - let allocation3 = cache - .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) - .unwrap(); + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation3.prefix_len, 4); // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. @@ -715,20 +690,20 @@ mod tests { #[test] fn allocator_frees_partially_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 20, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 99]))).unwrap(); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); assert_eq!(allocation1.prefix_len, 0); cache.free(allocation1.blocks.clone(), allocation1.allocation_id); let allocation2 = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99]))) + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) .unwrap(); assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); assert_eq!(allocation2.prefix_len, 2); let allocation3 = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99]))) + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) .unwrap(); assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation3.prefix_len, 2); @@ -740,14 +715,14 @@ mod tests { assert_eq!(cache.free_blocks.len(), 11); let allocation4 = cache - .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99]))) + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) .unwrap(); assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); assert_eq!(allocation4.prefix_len, 6); assert_eq!(cache.free_blocks.len(), 11); let allocation5 = cache - .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99]))) + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) .unwrap(); assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); assert_eq!(allocation5.prefix_len, 6); diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index fe77257b..0ac55b42 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -268,6 +268,9 @@ class FlashCausalLMBatch(Batch): assert ( prefix_len <= orig_input_length ), f"Prefix {prefix_len} vs input {orig_input_length}" + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 prefix_ids.append(tokenized_input[:prefix_len]) tokenized_input = tokenized_input[prefix_len:] @@ -1157,13 +1160,6 @@ class FlashCausalLM(Model): "input_lengths": input_lengths_tensor, "prefix_lengths": prefix_lengths_tensor, } - seqlen = Seqlen( - input_lengths=input_lengths_tensor, - prefix_lengths=prefix_lengths_tensor, - cu_seqlen_q=None, - max_q=1, - max_k=max_s, - ) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1199,6 +1195,13 @@ class FlashCausalLM(Model): prefix_lens=prefix_lengths, prefix_lens_tensor=prefix_lengths_tensor, ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1215,6 +1218,13 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1517,9 +1527,7 @@ class FlashCausalLM(Model): cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths cuda_graph["prefix_lengths"].zero_() cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor