This commit is contained in:
Nicolas Patry 2024-08-22 14:34:12 +02:00
parent 0bf4eb9683
commit 5eb6ea0063
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
6 changed files with 108 additions and 147 deletions

View File

@ -35,20 +35,14 @@ impl BackendV3 {
window_size: Option<u32>,
speculate: u32,
) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
matches!(prefix_caching.as_str(), "true" | "1")
} else {
true
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else if prefix_caching {
Attention::FlashInfer
} else {
Attention::FlashDecoding
};
let prefix_caching =
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").expect("attention env var");
let attention: Attention = attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = if attention == Attention::FlashDecoding {
256
} else if attention == Attention::FlashInfer {

View File

@ -1,4 +1,4 @@
use std::{cmp::min, sync::Arc};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
@ -91,11 +91,7 @@ async fn block_allocator_task(
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching);
while let Some(cmd) = receiver.recv().await {
match cmd {
BlockAllocatorCommand::Free {
@ -128,83 +124,12 @@ enum BlockAllocatorCommand {
},
}
pub trait Allocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some(BlockAllocation {
allocation_id: 0,
blocks,
slots,
prefix_len: 0,
block_allocator: None,
})
}
}
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}
// pub trait Allocator {
// fn allocate(
// &mut self,
// tokens: u32,
// prefill_tokens: Option<Arc<Vec<u32>>>,
// ) -> Option<BlockAllocation>;
//
// fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
// }

View File

@ -5,7 +5,7 @@ use std::{
use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::{Allocator, BlockAllocation};
use crate::block_allocator::BlockAllocation;
pub struct RadixAllocator {
allocation_id: u64,
@ -21,10 +21,18 @@ pub struct RadixAllocator {
// This isn't used because the prefix need to match without the windowing
// mecanism. This at worst is overallocating, not necessarily being wrong.
window_size: Option<u32>,
/// Wether to actual use the radix tree for searching or not.
prefix_caching: bool,
}
impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
pub fn new(
block_size: u32,
n_blocks: u32,
window_size: Option<u32>,
prefix_caching: bool,
) -> Self {
assert_eq!(
block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}",
@ -42,6 +50,7 @@ impl RadixAllocator {
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
window_size,
prefix_caching,
}
}
@ -69,23 +78,25 @@ impl RadixAllocator {
}
}
impl Allocator for RadixAllocator {
fn allocate(
// Allocator trait
impl RadixAllocator {
pub fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
// Even if this allocation fails below, we need to increase he
// refcount to ensure that the prefix that was found is not evicted.
let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) {
(true, Some(prefill_tokens)) => {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
// Even if this allocation fails below, we need to increase he
// refcount to ensure that the prefix that was found is not evicted.
node_id
} else {
self.cache_blocks.root_id()
node_id
}
_ => self.cache_blocks.root_id(),
};
self.cache_blocks
@ -126,7 +137,7 @@ impl Allocator for RadixAllocator {
})
}
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
pub fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
@ -574,13 +585,11 @@ where
mod tests {
use std::sync::Arc;
use crate::block_allocator::Allocator;
use super::RadixAllocator;
#[test]
fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None);
let mut cache = RadixAllocator::new(1, 12, None, true);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.slots, allocation.slots);
@ -592,9 +601,23 @@ mod tests {
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_doesnt_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None, false);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.slots, allocation.slots);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]);
assert_eq!(allocation.prefix_len, 0);
}
#[test]
fn allocator_collects_older_prefixes_first() {
let mut cache = RadixAllocator::new(1, 7, None);
let mut cache = RadixAllocator::new(1, 7, None, true);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation1.prefix_len, 0);
@ -614,7 +637,7 @@ mod tests {
#[test]
fn allocator_frees_fully_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 10, None);
let mut cache = RadixAllocator::new(1, 10, None, true);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
@ -630,7 +653,7 @@ mod tests {
#[test]
fn allocator_frees_partially_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 20, None);
let mut cache = RadixAllocator::new(1, 20, None, true);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
assert_eq!(allocation1.prefix_len, 0);

View File

@ -24,6 +24,44 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
fn resolve_attention(config: &Config, lora_adapters: &Option<String>) -> (String, String) {
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if prefix_caching.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
config.model_type.as_ref().unwrap()
);
prefix_caching = Some("0".to_string());
attention = Some("flashdecoding".to_string());
}
}
_ => {}
}
}
_ => {
if prefix_caching.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
prefix_caching = Some("0".to_string());
attention = Some("flashdecoding".to_string());
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
(prefix_caching, attention)
}
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
@ -1496,28 +1534,10 @@ fn main() -> Result<(), LauncherError> {
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if args.lora_adapters.is_some() {
tracing::info!("Disabling prefix caching because of lora adapters");
std::env::set_var("USE_PREFIX_CACHING", "0");
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
std::env::set_var("USE_PREFIX_CACHING", "0");
std::env::set_var("ATTENTION", "flashdecoding");
}
_ => {}
}
}
_ => {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
std::env::set_var("USE_PREFIX_CACHING", "0");
std::env::set_var("ATTENTION", "flashdecoding");
}
}
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let quantize = config.quantize;
// Quantization usually means you're even more RAM constrained.

View File

@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
CUDA_GRAPHS,
PREFIX_CACHING,
get_adapter_to_index,
)
from text_generation_server.layers.attention import Seqlen
@ -266,7 +265,7 @@ class FlashCausalLMBatch(Batch):
orig_input_length = len(tokenized_input)
if PREFIX_CACHING:
if ATTENTION == "flashinfer":
prefix_len = r.prefix_len
if prefix_len == orig_input_length:
assert prefix_len > 0
@ -1452,7 +1451,7 @@ class FlashCausalLM(Model):
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,

View File

@ -5,9 +5,9 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "flashdecoding")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
ATTENTION in _expected