This commit is contained in:
Nicolas Patry 2024-08-22 14:34:12 +02:00
parent 0bf4eb9683
commit 5eb6ea0063
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
6 changed files with 108 additions and 147 deletions

View File

@ -35,20 +35,14 @@ impl BackendV3 {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { let prefix_caching =
matches!(prefix_caching.as_str(), "true" | "1") std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
} else { let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
true let attention: String = std::env::var("ATTENTION").expect("attention env var");
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") { let attention: Attention = attention
attention .parse()
.parse() .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else if prefix_caching {
Attention::FlashInfer
} else {
Attention::FlashDecoding
};
let block_size = if attention == Attention::FlashDecoding { let block_size = if attention == Attention::FlashDecoding {
256 256
} else if attention == Attention::FlashInfer { } else if attention == Attention::FlashInfer {

View File

@ -1,4 +1,4 @@
use std::{cmp::min, sync::Arc}; use std::sync::Arc;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator; use crate::radix::RadixAllocator;
@ -91,11 +91,7 @@ async fn block_allocator_task(
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>, mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) { ) {
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching { let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching);
Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
BlockAllocatorCommand::Free { BlockAllocatorCommand::Free {
@ -128,83 +124,12 @@ enum BlockAllocatorCommand {
}, },
} }
pub trait Allocator { // pub trait Allocator {
fn allocate( // fn allocate(
&mut self, // &mut self,
tokens: u32, // tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>, // prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>; // ) -> Option<BlockAllocation>;
//
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64); // fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
} // }
/// Block allocator backed by a plain list of free block ids.
pub struct SimpleAllocator {
    /// Ids of the blocks that are currently available for allocation.
    free_blocks: Vec<u32>,
    /// Number of slots contained in a single block.
    block_size: u32,
    /// Optional window size, in tokens, used to cap how many blocks one allocation needs.
    window_size: Option<u32>,
}

impl SimpleAllocator {
    /// Creates an allocator over `blocks` blocks of `block_size` slots each.
    ///
    /// Block 0 is reserved for health checks, so only ids `1..blocks` start out free.
    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
        let free_blocks: Vec<u32> = (1..blocks).collect();
        Self {
            free_blocks,
            block_size,
            window_size,
        }
    }
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some(BlockAllocation {
allocation_id: 0,
blocks,
slots,
prefix_len: 0,
block_allocator: None,
})
}
}
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}

View File

@ -5,7 +5,7 @@ use std::{
use slotmap::{DefaultKey, SlotMap}; use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::{Allocator, BlockAllocation}; use crate::block_allocator::BlockAllocation;
pub struct RadixAllocator { pub struct RadixAllocator {
allocation_id: u64, allocation_id: u64,
@ -21,10 +21,18 @@ pub struct RadixAllocator {
// This isn't used because the prefix need to match without the windowing // This isn't used because the prefix need to match without the windowing
// mecanism. This at worst is overallocating, not necessarily being wrong. // mecanism. This at worst is overallocating, not necessarily being wrong.
window_size: Option<u32>, window_size: Option<u32>,
/// Whether to actually use the radix tree for searching or not.
prefix_caching: bool,
} }
impl RadixAllocator { impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self { pub fn new(
block_size: u32,
n_blocks: u32,
window_size: Option<u32>,
prefix_caching: bool,
) -> Self {
assert_eq!( assert_eq!(
block_size, 1, block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}", "Radix tree allocator only works with block_size=1, was: {}",
@ -42,6 +50,7 @@ impl RadixAllocator {
// Block 0 is reserved for health checks. // Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(), free_blocks: (1..n_blocks).collect(),
window_size, window_size,
prefix_caching,
} }
} }
@ -69,23 +78,25 @@ impl RadixAllocator {
} }
} }
impl Allocator for RadixAllocator { // Allocator trait
fn allocate( impl RadixAllocator {
pub fn allocate(
&mut self, &mut self,
tokens: u32, tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>, prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> { ) -> Option<BlockAllocation> {
let mut blocks = vec![]; let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) {
let node_id = self (true, Some(prefill_tokens)) => {
.cache_blocks let node_id = self
.find(prefill_tokens.as_slice(), &mut blocks); .cache_blocks
// Even if this allocation fails below, we need to increase he .find(prefill_tokens.as_slice(), &mut blocks);
// refcount to ensure that the prefix that was found is not evicted. // Even if this allocation fails below, we need to increase he
// refcount to ensure that the prefix that was found is not evicted.
node_id node_id
} else { }
self.cache_blocks.root_id() _ => self.cache_blocks.root_id(),
}; };
self.cache_blocks self.cache_blocks
@ -126,7 +137,7 @@ impl Allocator for RadixAllocator {
}) })
} }
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) { pub fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) { let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation, Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."), None => unreachable!("Tried to free an unknown allocation."),
@ -574,13 +585,11 @@ where
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use crate::block_allocator::Allocator;
use super::RadixAllocator; use super::RadixAllocator;
#[test] #[test]
fn allocator_reuses_prefixes() { fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None); let mut cache = RadixAllocator::new(1, 12, None, true);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.slots, allocation.slots); assert_eq!(allocation.slots, allocation.slots);
@ -592,9 +601,23 @@ mod tests {
assert_eq!(allocation.prefix_len, 4); assert_eq!(allocation.prefix_len, 4);
} }
// Builds the allocator with `prefix_caching = false` and checks that allocating
// the same prefill tokens twice reports `prefix_len == 0` both times.
#[test]
fn allocator_doesnt_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None, false);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
// NOTE(review): this compares `slots` with itself, so it can never fail.
// Presumably it was meant to compare against `allocation.blocks` — confirm.
assert_eq!(allocation.slots, allocation.slots);
assert_eq!(allocation.prefix_len, 0);
// Release the first allocation before requesting the same prompt again.
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]);
assert_eq!(allocation.prefix_len, 0);
}
#[test] #[test]
fn allocator_collects_older_prefixes_first() { fn allocator_collects_older_prefixes_first() {
let mut cache = RadixAllocator::new(1, 7, None); let mut cache = RadixAllocator::new(1, 7, None, true);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation1.prefix_len, 0); assert_eq!(allocation1.prefix_len, 0);
@ -614,7 +637,7 @@ mod tests {
#[test] #[test]
fn allocator_frees_fully_overlapping_prefills() { fn allocator_frees_fully_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 10, None); let mut cache = RadixAllocator::new(1, 10, None, true);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
@ -630,7 +653,7 @@ mod tests {
#[test] #[test]
fn allocator_frees_partially_overlapping_prefills() { fn allocator_frees_partially_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 20, None); let mut cache = RadixAllocator::new(1, 20, None, true);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
assert_eq!(allocation1.prefix_len, 0); assert_eq!(allocation1.prefix_len, 0);

View File

@ -24,6 +24,44 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime; mod env_runtime;
/// Resolves the prefix caching and attention settings to export for the server.
///
/// `USE_PREFIX_CACHING` and `ATTENTION` are read from the environment first.
/// When `USE_PREFIX_CACHING` was not set by the user, the model `config` and
/// `lora_adapters` drive the following automatic overrides:
/// - head dim other than 64, 128 or 256 (or missing): prefix caching `"0"` and
///   attention `"flashdecoding"`;
/// - model type `gemma2`, `falcon` or `deepseek_v2`: prefix caching `"0"` and
///   attention `"flashdecoding"`;
/// - lora adapters present: prefix caching `"0"`.
///
/// An explicit `USE_PREFIX_CACHING` disables every automatic override.
///
/// Returns `(prefix_caching, attention)`; values still unset at the end default
/// to `"true"` and `"flashinfer"` respectively.
fn resolve_attention(config: &Config, lora_adapters: &Option<String>) -> (String, String) {
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();

    // Snapshot the user's choice before any automatic override. Checking
    // `prefix_caching.is_none()` again after the lora branch has filled it in
    // would skip forcing flash decoding for models that require it.
    let user_set_prefix_caching = prefix_caching.is_some();

    if !user_set_prefix_caching {
        match config.head_dim {
            Some(64 | 128 | 256) => {
                if lora_adapters.is_some() {
                    tracing::info!("Disabling prefix caching because of lora adapters");
                    prefix_caching = Some("0".to_string());
                }
                if let Some(model_type @ ("gemma2" | "falcon" | "deepseek_v2")) =
                    config.model_type.as_deref()
                {
                    // Required because gemma2 needs bfloat16 which is not supported by
                    // flashinfer ?
                    tracing::info!(
                        "Forcing flash decoding because model {} requires it",
                        model_type
                    );
                    prefix_caching = Some("0".to_string());
                    attention = Some("flashdecoding".to_string());
                }
            }
            _ => {
                tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
                prefix_caching = Some("0".to_string());
                attention = Some("flashdecoding".to_string());
            }
        }
    }

    let prefix_caching = prefix_caching.unwrap_or_else(|| "true".to_string());
    let attention = attention.unwrap_or_else(|| "flashinfer".to_string());
    (prefix_caching, attention)
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct RawConfig { struct RawConfig {
max_position_embeddings: Option<usize>, max_position_embeddings: Option<usize>,
@ -1496,28 +1534,10 @@ fn main() -> Result<(), LauncherError> {
let config: RawConfig = serde_json::from_str(&content)?; let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into(); let config: Config = config.into();
match config.head_dim { let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
Some(h) if h == 64 || h == 128 || h == 256 => { tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
if args.lora_adapters.is_some() { std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
tracing::info!("Disabling prefix caching because of lora adapters"); std::env::set_var("ATTENTION", attention);
std::env::set_var("USE_PREFIX_CACHING", "0");
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
std::env::set_var("USE_PREFIX_CACHING", "0");
std::env::set_var("ATTENTION", "flashdecoding");
}
_ => {}
}
}
_ => {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
std::env::set_var("USE_PREFIX_CACHING", "0");
std::env::set_var("ATTENTION", "flashdecoding");
}
}
let quantize = config.quantize; let quantize = config.quantize;
// Quantization usually means you're even more RAM constrained. // Quantization usually means you're even more RAM constrained.

View File

@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
PREFIX_CACHING,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -266,7 +265,7 @@ class FlashCausalLMBatch(Batch):
orig_input_length = len(tokenized_input) orig_input_length = len(tokenized_input)
if PREFIX_CACHING: if ATTENTION == "flashinfer":
prefix_len = r.prefix_len prefix_len = r.prefix_len
if prefix_len == orig_input_length: if prefix_len == orig_input_length:
assert prefix_len > 0 assert prefix_len > 0
@ -1452,7 +1451,7 @@ class FlashCausalLM(Model):
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING: if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,

View File

@ -5,9 +5,9 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"} PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "flashdecoding") ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected