From 5eb6ea006393812770fd9edd0bc2d7ed937efe0e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 22 Aug 2024 14:34:12 +0200 Subject: [PATCH] Tmp --- backends/v3/src/backend.rs | 22 ++--- backends/v3/src/block_allocator.rs | 97 +++---------------- backends/v3/src/radix.rs | 63 ++++++++---- launcher/src/main.rs | 64 +++++++----- .../models/flash_causal_lm.py | 5 +- .../text_generation_server/models/globals.py | 4 +- 6 files changed, 108 insertions(+), 147 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 396287d32..ec80d55a3 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,20 +35,14 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { - matches!(prefix_caching.as_str(), "true" | "1") - } else { - true - }; - let attention = if let Ok(attention) = std::env::var("ATTENTION") { - attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) - } else if prefix_caching { - Attention::FlashInfer - } else { - Attention::FlashDecoding - }; + let prefix_caching = + std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); + let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); + let attention: String = std::env::var("ATTENTION").expect("attention env var"); + + let attention: Attention = attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); let block_size = if attention == Attention::FlashDecoding { 256 } else if attention == Attention::FlashInfer { diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index c5503b8c6..1cbc58684 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,4 +1,4 @@ -use std::{cmp::min, sync::Arc}; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; @@ -91,11 +91,7 @@ async fn block_allocator_task( window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - let mut allocator: Box = if prefix_caching { - Box::new(RadixAllocator::new(block_size, blocks, window_size)) - } else { - Box::new(SimpleAllocator::new(blocks, block_size, window_size)) - }; + let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching); while let Some(cmd) = receiver.recv().await { match cmd { BlockAllocatorCommand::Free { @@ -128,83 +124,12 @@ enum BlockAllocatorCommand { }, } -pub trait Allocator { - fn allocate( - &mut self, - tokens: u32, - prefill_tokens: Option>>, - ) -> Option; - - fn free(&mut self, blocks: Vec, allocation_id: u64); -} - -pub struct SimpleAllocator { - free_blocks: Vec, - block_size: u32, - window_size: Option, -} - -impl SimpleAllocator { - fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { - SimpleAllocator { - block_size, - // Block 0 is reserved for health checks - free_blocks: (1..blocks).collect(), - window_size, - } - } -} - -impl Allocator for SimpleAllocator { - fn allocate( - &mut self, - tokens: u32, - _prefill_tokens: Option>>, - ) -> Option { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match self.window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + self.block_size - 1) / self.block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - if required_blocks > self.free_blocks.len() as u32 { - None - } else { - let blocks = self - .free_blocks - .split_off(self.free_blocks.len() - required_blocks as usize); - let mut slots = - Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some(BlockAllocation { - allocation_id: 0, - blocks, - slots, - prefix_len: 0, - block_allocator: None, - }) - } - } - - fn free(&mut self, blocks: Vec, _allocation_id: u64) { - self.free_blocks.extend(blocks) - } -} +// pub trait Allocator { +// fn allocate( +// &mut self, +// tokens: u32, +// prefill_tokens: Option>>, +// ) -> Option; +// +// fn free(&mut self, blocks: Vec, allocation_id: u64); +// } diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index c9ac12c23..0dfaf0ab3 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -5,7 +5,7 @@ use std::{ use slotmap::{DefaultKey, SlotMap}; -use crate::block_allocator::{Allocator, BlockAllocation}; +use crate::block_allocator::BlockAllocation; pub struct RadixAllocator { allocation_id: u64, @@ -21,10 +21,18 @@ pub struct RadixAllocator { // This isn't used because the prefix need to match without the windowing // mecanism. This at worst is overallocating, not necessarily being wrong. window_size: Option, + + /// Wether to actual use the radix tree for searching or not. + prefix_caching: bool, } impl RadixAllocator { - pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + pub fn new( + block_size: u32, + n_blocks: u32, + window_size: Option, + prefix_caching: bool, + ) -> Self { assert_eq!( block_size, 1, "Radix tree allocator only works with block_size=1, was: {}", @@ -42,6 +50,7 @@ impl RadixAllocator { // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), window_size, + prefix_caching, } } @@ -69,23 +78,25 @@ impl RadixAllocator { } } -impl Allocator for RadixAllocator { - fn allocate( +// Allocator trait +impl RadixAllocator { + pub fn allocate( &mut self, tokens: u32, prefill_tokens: Option>>, ) -> Option { let mut blocks = vec![]; - let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { - let node_id = self - .cache_blocks - .find(prefill_tokens.as_slice(), &mut blocks); - // Even if this allocation fails below, we need to increase he - // refcount to ensure that the prefix that was found is not evicted. + let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) { + (true, Some(prefill_tokens)) => { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. - node_id - } else { - self.cache_blocks.root_id() + node_id + } + _ => self.cache_blocks.root_id(), }; self.cache_blocks @@ -126,7 +137,7 @@ impl Allocator for RadixAllocator { }) } - fn free(&mut self, blocks: Vec, allocation_id: u64) { + pub fn free(&mut self, blocks: Vec, allocation_id: u64) { let allocation = match self.allocations.remove(&allocation_id) { Some(allocation) => allocation, None => unreachable!("Tried to free an unknown allocation."), @@ -574,13 +585,11 @@ where mod tests { use std::sync::Arc; - use crate::block_allocator::Allocator; - use super::RadixAllocator; #[test] fn allocator_reuses_prefixes() { - let mut cache = RadixAllocator::new(1, 12, None); + let mut cache = RadixAllocator::new(1, 12, None, true); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.slots, allocation.slots); @@ -592,9 +601,23 @@ mod tests { assert_eq!(allocation.prefix_len, 4); } + #[test] + fn allocator_doesnt_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None, false); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]); + assert_eq!(allocation.prefix_len, 0); + } + #[test] fn allocator_collects_older_prefixes_first() { - let mut cache = RadixAllocator::new(1, 7, None); + let mut cache = RadixAllocator::new(1, 7, None, true); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation1.prefix_len, 0); @@ -614,7 +637,7 @@ mod tests { #[test] fn allocator_frees_fully_overlapping_prefills() { - let mut cache = RadixAllocator::new(1, 10, None); + let mut cache = RadixAllocator::new(1, 10, None, true); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); @@ -630,7 +653,7 @@ mod tests { #[test] fn allocator_frees_partially_overlapping_prefills() { - let mut cache = RadixAllocator::new(1, 20, None); + let mut cache = RadixAllocator::new(1, 20, None, true); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); assert_eq!(allocation1.prefix_len, 0); diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 627dbd140..35dde5d09 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -24,6 +24,44 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter}; mod env_runtime; +fn resolve_attention(config: &Config, lora_adapters: &Option) -> (String, String) { + let mut prefix_caching: Option = std::env::var("USE_PREFIX_CACHING").ok(); + let mut attention: Option = std::env::var("ATTENTION").ok(); + match config.head_dim { + Some(h) if h == 64 || h == 128 || h == 256 => { + if lora_adapters.is_some() && prefix_caching.is_none() { + tracing::info!("Disabling prefix caching because of lora adapters"); + prefix_caching = Some("0".to_string()); + } + match config.model_type.as_deref() { + Some("gemma2") | Some("falcon") | Some("deepseek_v2") => { + // Required because gemma2 needs bfloat16 which is not supported by + // flashinfer ? + if prefix_caching.is_none() { + tracing::info!( + "Forcing flash decoding because model {} requires it", + config.model_type.as_ref().unwrap() + ); + prefix_caching = Some("0".to_string()); + attention = Some("flashdecoding".to_string()); + } + } + _ => {} + } + } + _ => { + if prefix_caching.is_none() { + tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching"); + prefix_caching = Some("0".to_string()); + attention = Some("flashdecoding".to_string()); + } + } + } + let prefix_caching = prefix_caching.unwrap_or("true".to_string()); + let attention = attention.unwrap_or("flashinfer".to_string()); + (prefix_caching, attention) +} + #[derive(Deserialize)] struct RawConfig { max_position_embeddings: Option, @@ -1496,28 +1534,10 @@ fn main() -> Result<(), LauncherError> { let config: RawConfig = serde_json::from_str(&content)?; let config: Config = config.into(); - match config.head_dim { - Some(h) if h == 64 || h == 128 || h == 256 => { - if args.lora_adapters.is_some() { - tracing::info!("Disabling prefix caching because of lora adapters"); - std::env::set_var("USE_PREFIX_CACHING", "0"); - } - match config.model_type.as_deref() { - Some("gemma2") | Some("falcon") | Some("deepseek_v2") => { - // Required because gemma2 needs bfloat16 which is not supported by - // flashinfer ? - std::env::set_var("USE_PREFIX_CACHING", "0"); - std::env::set_var("ATTENTION", "flashdecoding"); - } - _ => {} - } - } - _ => { - tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching"); - std::env::set_var("USE_PREFIX_CACHING", "0"); - std::env::set_var("ATTENTION", "flashdecoding"); - } - } + let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); + tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}"); + std::env::set_var("USE_PREFIX_CACHING", prefix_caching); + std::env::set_var("ATTENTION", attention); let quantize = config.quantize; // Quantization usually means you're even more RAM constrained. diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index dd4203e06..3d962bede 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,7 +43,6 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, - PREFIX_CACHING, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -266,7 +265,7 @@ class FlashCausalLMBatch(Batch): orig_input_length = len(tokenized_input) - if PREFIX_CACHING: + if ATTENTION == "flashinfer": prefix_len = r.prefix_len if prefix_len == orig_input_length: assert prefix_len > 0 @@ -1452,7 +1451,7 @@ class FlashCausalLM(Model): if cu_seqlen_prefill is not None or cuda_graph is None: input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: + if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 5dc8b6852..aaed2475a 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,9 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"} +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "flashdecoding") +ATTENTION = os.getenv("ATTENTION") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected