diff --git a/Cargo.lock b/Cargo.lock index 92367d1ee..3a5845a77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4045,6 +4045,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "slotmap", "text-generation-router", "thiserror", "tokenizers", diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index a996b14fa..b321278c1 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -156,6 +156,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index ae8a899b3..1cc173e33 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 5d9a140b0..129ceb9cc 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -33,6 +33,7 @@ rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" +slotmap = "1.0.7" thiserror = "1.0.48" tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 68ddf00b9..cbcbff72a 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,15 +35,24 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { + let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { + matches!(prefix_caching.as_str(), "true" | "1") + } else { + false + }; let attention = if let Ok(attention) = std::env::var("ATTENTION") { attention .parse() .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) + } else if prefix_caching { + Attention::FlashInfer } else { Attention::Paged }; let block_size = if attention == Attention::FlashDecoding { 256 + } else if attention == Attention::FlashInfer { + 1 } else { 16 }; @@ -51,6 +60,7 @@ impl BackendV3 { let queue = Queue::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 7467fd859..05c2bd30d 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,16 +1,26 @@ -use std::cmp::min; +use std::{cmp::min, sync::Arc}; use tokio::sync::{mpsc, oneshot}; +use crate::radix::RadixAllocator; + #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { + pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, - block_allocator: BlockAllocator, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } } } @@ -24,6 +34,7 @@ impl BlockAllocator { pub(crate) fn new( max_batch_total_tokens: u32, block_size: u32, + prefix_caching: bool, window_size: Option, ) -> Self { // Create channel @@ -33,6 +44,7 @@ impl BlockAllocator { tokio::spawn(block_allocator_task( max_batch_total_tokens / block_size, block_size, + prefix_caching, window_size, receiver, )); @@ -42,28 +54,32 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -71,54 +87,29 @@ impl BlockAllocator { async fn block_allocator_task( blocks: u32, block_size: u32, + prefix_caching: bool, window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); } } } @@ -128,9 +119,92 @@ async fn block_allocator_task( enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, }, } + +pub(crate) trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} + +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index c407687b7..6282759e8 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -157,6 +157,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index afb13cdc3..2f78da034 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -245,6 +245,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index a6f891692..c8fc55f88 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -2,6 +2,7 @@ mod backend; mod block_allocator; mod client; mod queue; +mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index b457389c4..13544235e 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -46,6 +46,7 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -57,6 +58,7 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -109,6 +111,7 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -117,6 +120,7 @@ async fn queue_task( let mut state = State::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -176,12 +180,19 @@ impl State { fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, ) -> Self { - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); Self { entries: VecDeque::with_capacity(128), @@ -305,7 +316,10 @@ impl State { + self.speculate - 1; - match block_allocator.allocate(tokens).await { + match block_allocator + .allocate(tokens, entry.request.input_ids.clone()) + .await + { None => { // Entry is over budget // Add it back to the front @@ -331,11 +345,12 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -372,6 +387,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + prefix_len, adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time @@ -480,6 +496,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -492,6 +510,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, truncate: 0, decoder_input_details: false, @@ -527,7 +546,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -543,7 +562,7 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -551,7 +570,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -583,7 +602,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -603,7 +622,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -636,14 +655,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -651,7 +670,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -684,7 +703,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -700,7 +719,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -725,7 +744,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(false, 1, false, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -744,7 +763,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 000000000..0464b9f8c --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,755 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use slotmap::{DefaultKey, SlotMap}; + +use crate::block_allocator::{Allocator, BlockAllocation}; + +pub struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + assert_eq!( + block_size, 1, + "Radix tree allocator only works with block_size=1, was: {}", + block_size + ); + if window_size.is_some() { + unimplemented!("Window size not supported in the prefix-caching block allocator yet"); + } + + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. + + node_id + } else { + self.cache_blocks.root_id() + }; + + self.cache_blocks + .incref(prefix_node) + .expect("Failed to increment refcount"); + + let prefix_len = blocks.len(); + let suffix_len = tokens - prefix_len as u32; + + match self.alloc_or_reclaim(suffix_len as usize) { + Some(suffix_blocks) => blocks.extend(suffix_blocks), + None => { + self.cache_blocks + .decref(prefix_node) + .expect("Failed to decrement refcount"); + return None; + } + } + + // 1:1 mapping of blocks and slots. + let slots = blocks.clone(); + + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some(BlockAllocation { + allocation_id: self.allocation_id, + block_allocator: None, + blocks, + slots, + prefix_len: prefix_len as u32, + }) + } + + fn free(&mut self, blocks: Vec, allocation_id: u64) { + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks + .decref(allocation.prefix_node) + .expect("Failed to decrement refcount"); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + let prefix_len = self + .cache_blocks + .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + self.free_blocks + .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + } + + // Free non-prefill blocks. + self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + } else { + self.free_blocks.extend(blocks); + } + } +} + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. + +#[derive(Debug)] +pub enum TrieError { + InvalidNodeId, + RefCountUnderflow, + BlockTokenCountMismatch, +} + +pub type NodeId = DefaultKey; + +#[derive(Debug)] +pub struct RadixTrie { + /// Identifier of the root nod. + root: DefaultKey, + + /// Leave node identifiers ordered by increasing recency. + leaves: BTreeSet<(u64, NodeId)>, + + /// All trie nodes. + nodes: SlotMap, + + /// Time as a monotonically increating counter to avoid the system + /// call that a real time lookup would require. + time: u64, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new() -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if let Some(&child_id) = node.children.get(&key[0]) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = child.key.shared_prefix_len(key); + blocks.extend(&child.blocks[..shared_prefix_len]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + + node_id + } + + /// Decrease the reference count of a node. + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could beevicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. + let mut evicted = Vec::new(); + + while let Some((last_access, node_id)) = self.leaves.pop_first() { + let blocks_needed = n_blocks - evicted.len(); + + let node = self.nodes.get(node_id).expect("Leave does not exist"); + if blocks_needed >= node.blocks.len() { + // We need to evict the whole node if we need more blocks than it has. + let node = self.remove_node(node_id); + evicted.extend(node.blocks); + + if evicted.len() >= n_blocks { + break; + } + } else { + // The node has more blocks than needed, so we'll just remove + // the required number of blocks and leave the remaining blocks + // untouched. + let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + node.key.truncate(node.blocks.len() - blocks_needed); + evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + self.insert_(self.root, tokens, blocks) + } + + /// Insertion worker. + fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. + + if tokens.len() != blocks.len() { + return Err(TrieError::BlockTokenCountMismatch); + } + + if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = child.key.shared_prefix_len(tokens); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This funcion unwraps, an invalid node_id is a programming error. + + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let mut parent_blocks = node.blocks.split_off(prefix_len); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = node.key[0]; + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. + fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = key[0]; + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(first, child_id).is_none() { + // Only increase reference count if child does not replace another child. + self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + parent.children.remove(&node.key[0]); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + self.nodes.remove(node_id); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. + if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +/// Helper trait to get the length of the shared prefix of two sequences. +trait SharedPrefixLen { + fn shared_prefix_len(&self, other: &Self) -> usize; +} + +impl SharedPrefixLen for [T] +where + T: PartialEq, +{ + fn shared_prefix_len(&self, other: &Self) -> usize { + self.iter().zip(other).take_while(|(a, b)| a == b).count() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::block_allocator::Allocator; + + use super::RadixAllocator; + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. + assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = super::RadixTrie::new(); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. + assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703f..7494d5b5d 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,6 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + prefix_len: 0, adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 926c878ea..68eea7ac9 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -3,22 +3,23 @@ syntax = "proto3"; package generate.v3; service TextGenerationService { - /// Model Info - rpc Info (InfoRequest) returns (InfoResponse) {} - /// Service discovery - rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} - /// Empties batch cache - rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Remove requests from a cached batch - rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); - /// Warmup the model and compute max cache size - rpc Warmup (WarmupRequest) returns (WarmupResponse); - /// Prefill batch and decode first token - rpc Prefill (PrefillRequest) returns (PrefillResponse); - /// Decode token for a list of prefilled batches - rpc Decode (DecodeRequest) returns (DecodeResponse); - /// Health check - rpc Health (HealthRequest) returns (HealthResponse); + /// Model Info + rpc Info(InfoRequest) returns (InfoResponse) {} + /// Service discovery + rpc ServiceDiscovery(ServiceDiscoveryRequest) + returns (ServiceDiscoveryResponse) {} + /// Empties batch cache + rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse); + /// Remove requests from a cached batch + rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup(WarmupRequest) returns (WarmupResponse); + /// Prefill batch and decode first token + rpc Prefill(PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc Decode(DecodeRequest) returns (DecodeResponse); + /// Health check + rpc Health(HealthRequest) returns (HealthResponse); } message HealthRequest {} @@ -28,240 +29,239 @@ message HealthResponse {} message InfoRequest {} message InfoResponse { - bool requires_padding = 1; - string dtype = 2; - string device_type = 3; - optional uint32 window_size = 4; - uint32 speculate = 5; + bool requires_padding = 1; + string dtype = 2; + string device_type = 3; + optional uint32 window_size = 4; + uint32 speculate = 5; } /// Empty request message ServiceDiscoveryRequest {} message ServiceDiscoveryResponse { - /// Other shards urls - repeated string urls = 1; + /// Other shards urls + repeated string urls = 1; } message ClearCacheRequest { - /// Optional batch id - optional uint64 id = 1; + /// Optional batch id + optional uint64 id = 1; } /// Empty response message ClearCacheResponse {} message Image { - /// Binary image data. - bytes data = 1; + /// Binary image data. + bytes data = 1; - /// Image MIME type. - string mimetype = 2; + /// Image MIME type. + string mimetype = 2; } message InputChunk { - oneof chunk { - /// Plain text data - string text = 1; - /// Image data - Image image = 2; - } + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } } -message Input { - repeated InputChunk chunks = 1; - } +message Input { repeated InputChunk chunks = 1; } enum GrammarType { - GRAMMAR_TYPE_NONE = 0; - GRAMMAR_TYPE_JSON = 1; - GRAMMAR_TYPE_REGEX = 2; + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; } message NextTokenChooserParameters { - /// exponential scaling output probability distribution - float temperature = 1; - /// restricting to the k highest probability elements - uint32 top_k = 2; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float top_p = 3; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float typical_p = 4; - /// apply sampling on the logits - bool do_sample = 5; - /// random seed for sampling - uint64 seed = 6; - /// repetition penalty - float repetition_penalty = 7; - /// frequency penalty - float frequency_penalty = 9; - /// token watermarking using "A Watermark for Large Language Models" - bool watermark = 8; - /// grammar (applied if not empty) - string grammar = 10; - /// grammar type - GrammarType grammar_type = 11; + /// exponential scaling output probability distribution + float temperature = 1; + /// restricting to the k highest probability elements + uint32 top_k = 2; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float top_p = 3; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float typical_p = 4; + /// apply sampling on the logits + bool do_sample = 5; + /// random seed for sampling + uint64 seed = 6; + /// repetition penalty + float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 9; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; } message StoppingCriteriaParameters { - /// Maximum number of generated tokens - uint32 max_new_tokens = 1; - /// Optional stopping sequences - repeated string stop_sequences = 2; - /// Ignore end of sequence token - /// used for benchmarking - bool ignore_eos_token = 3; + /// Maximum number of generated tokens + uint32 max_new_tokens = 1; + /// Optional stopping sequences + repeated string stop_sequences = 2; + /// Ignore end of sequence token + /// used for benchmarking + bool ignore_eos_token = 3; } message Request { - /// Request ID - uint64 id = 1; - /// The generation context as chunks - Input input_chunks = 8; - /// The generation context, stringified input_chunks - string inputs = 2; - /// Context truncation - uint32 truncate = 3; - /// Next Token Chooser Parameters - NextTokenChooserParameters parameters = 4; - /// Stopping Criteria Parameters - StoppingCriteriaParameters stopping_parameters = 5; - /// Return prefill logprobs - bool prefill_logprobs = 6; - /// Return most likely n tokens - uint32 top_n_tokens = 7; - /// Paged attention blocks - repeated uint32 blocks = 9; - /// Paged attention slots - repeated uint32 slots = 10; - /// LORA adapter index - optional string adapter_id = 11; + /// Request ID + uint64 id = 1; + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks + string inputs = 2; + /// Context truncation + uint32 truncate = 3; + /// Next Token Chooser Parameters + NextTokenChooserParameters parameters = 4; + /// Stopping Criteria Parameters + StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_n_tokens = 7; + /// Paged attention blocks + repeated uint32 blocks = 9; + /// Paged attention slots + repeated uint32 slots = 10; + /// LORA adapter index + optional string adapter_id = 11; + /// Prefix length that can be retrieved from the KV cache. + uint32 prefix_len = 12; } message Batch { - /// Batch ID - uint64 id = 1; - /// Individual requests - repeated Request requests = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; - /// Maximum number of Paged Attention blocks - uint32 max_blocks = 5; + /// Batch ID + uint64 id = 1; + /// Individual requests + repeated Request requests = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; + /// Maximum number of Paged Attention blocks + uint32 max_blocks = 5; } message CachedBatch { - /// Batch ID - uint64 id = 1; - /// Individual requests ids - repeated uint64 request_ids = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; + /// Batch ID + uint64 id = 1; + /// Individual requests ids + repeated uint64 request_ids = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; } enum FinishReason { - FINISH_REASON_LENGTH = 0; - FINISH_REASON_EOS_TOKEN = 1; - FINISH_REASON_STOP_SEQUENCE = 2; + FINISH_REASON_LENGTH = 0; + FINISH_REASON_EOS_TOKEN = 1; + FINISH_REASON_STOP_SEQUENCE = 2; } message GeneratedText { - /// Output - string text = 1; - /// Number of generated tokens - uint32 generated_tokens = 2; - /// Finish reason - FinishReason finish_reason = 3; - /// Seed - optional uint64 seed = 4; + /// Output + string text = 1; + /// Number of generated tokens + uint32 generated_tokens = 2; + /// Finish reason + FinishReason finish_reason = 3; + /// Seed + optional uint64 seed = 4; } message Tokens { - /// Token IDs - repeated uint32 ids = 1; - /// Logprobs - repeated float logprobs = 2; - /// tokens - repeated string texts = 3; - /// special - repeated bool is_special = 4; + /// Token IDs + repeated uint32 ids = 1; + /// Logprobs + repeated float logprobs = 2; + /// tokens + repeated string texts = 3; + /// special + repeated bool is_special = 4; } message Generation { - /// Request ID - uint64 request_id = 1; - /// Prefill tokens (optional) - Tokens prefill_tokens = 2; - Tokens tokens = 3; - /// Complete generated text - optional GeneratedText generated_text = 4; - /// Top tokens - repeated Tokens top_tokens = 5; + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + Tokens prefill_tokens = 2; + Tokens tokens = 3; + /// Complete generated text + optional GeneratedText generated_text = 4; + /// Top tokens + repeated Tokens top_tokens = 5; } message FilterBatchRequest { - /// Batch ID - uint64 batch_id = 1; - /// Requests to keep - repeated uint64 request_ids = 2; + /// Batch ID + uint64 batch_id = 1; + /// Requests to keep + repeated uint64 request_ids = 2; } message FilterBatchResponse { - /// Filtered Batch (cached) - CachedBatch batch = 1; + /// Filtered Batch (cached) + CachedBatch batch = 1; } - message PrefillRequest { - /// Batch - Batch batch = 1; + /// Batch + Batch batch = 1; } message PrefillResponse { - /// Generation - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; + /// Generation + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; } message DecodeRequest { - /// Cached batches - repeated CachedBatch batches = 1; + /// Cached batches + repeated CachedBatch batches = 1; } message DecodeResponse { - /// Decodes - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; - /// Concatenate elapsed time in nanoseconds - optional uint64 concat_ns = 6; + /// Decodes + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; } message WarmupRequest { - /// Batch to warmup on - Batch batch = 1; - uint32 max_input_length = 2; - uint32 max_prefill_tokens = 3; - uint32 max_total_tokens = 4; + /// Batch to warmup on + Batch batch = 1; + uint32 max_input_length = 2; + uint32 max_prefill_tokens = 3; + uint32 max_total_tokens = 4; } message WarmupResponse { - /// Maximum number of tokens supported by the model - optional uint32 max_supported_total_tokens = 1; + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; } diff --git a/router/src/validation.rs b/router/src/validation.rs index 3d1a4103f..5011158af 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -11,6 +11,7 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; +use std::sync::Arc; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -115,13 +116,14 @@ impl Validation { } } + #[allow(clippy::type_complexity)] #[instrument(skip(self, inputs))] async fn validate_input( &self, inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -156,8 +158,10 @@ impl Validation { )); } + let input_ids = encoding.get_ids()[..input_length].to_owned(); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); - Ok((inputs, input_length, max_new_tokens)) + Ok((inputs, Some(input_ids), input_length, max_new_tokens)) } // Return inputs without validation else { @@ -180,7 +184,12 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs)], + None, + input_length, + max_new_tokens, + )) } } @@ -314,7 +323,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let (inputs, input_length, max_new_tokens) = self + let (inputs, input_ids, input_length, max_new_tokens) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -391,6 +400,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + input_ids: input_ids.map(Arc::new), decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -707,6 +717,7 @@ pub struct ValidStoppingParameters { #[derive(Debug, Clone)] pub struct ValidGenerateRequest { pub inputs: Vec, + pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index b58a5b802..abc354212 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,29 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -ATTENTION = os.getenv("ATTENTION", "paged") +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) +log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") + +ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") +if PREFIX_CACHING and ATTENTION != "flashinfer": + raise RuntimeError("Prefix caching is only supported with flashinfer") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + # This is overridden by the cli -BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 +BLOCK_SIZE: int +if ATTENTION == "flashdecoding": + BLOCK_SIZE = 256 +elif ATTENTION == "flashinfer": + BLOCK_SIZE = 1 +else: + BLOCK_SIZE = 16 cuda_graphs = os.getenv("CUDA_GRAPHS")