mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Tmp
This commit is contained in:
parent
0bf4eb9683
commit
5eb6ea0063
@ -35,20 +35,14 @@ impl BackendV3 {
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
|
||||
matches!(prefix_caching.as_str(), "true" | "1")
|
||||
} else {
|
||||
true
|
||||
};
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else if prefix_caching {
|
||||
Attention::FlashInfer
|
||||
} else {
|
||||
Attention::FlashDecoding
|
||||
};
|
||||
let prefix_caching =
|
||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
||||
|
||||
let attention: Attention = attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else if attention == Attention::FlashInfer {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{cmp::min, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::radix::RadixAllocator;
|
||||
@ -91,11 +91,7 @@ async fn block_allocator_task(
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||
} else {
|
||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||
};
|
||||
let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching);
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free {
|
||||
@ -128,83 +124,12 @@ enum BlockAllocatorCommand {
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
|
||||
/// Allocator backed by a plain list of free block ids.
pub struct SimpleAllocator {
    /// Ids of the blocks that are currently available.
    free_blocks: Vec<u32>,
    /// Number of slots contained in a single block.
    block_size: u32,
    /// Optional sliding-window length applied when sizing an allocation.
    window_size: Option<u32>,
}

impl SimpleAllocator {
    /// Builds an allocator managing block ids `1..blocks`.
    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
        // Block 0 is reserved for health checks
        let free_blocks: Vec<u32> = (1..blocks).collect();
        Self {
            free_blocks,
            block_size,
            window_size,
        }
    }
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BlockAllocation {
|
||||
allocation_id: 0,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: 0,
|
||||
block_allocator: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
|
||||
// pub trait Allocator {
|
||||
// fn allocate(
|
||||
// &mut self,
|
||||
// tokens: u32,
|
||||
// prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
// ) -> Option<BlockAllocation>;
|
||||
//
|
||||
// fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
// }
|
||||
|
@ -5,7 +5,7 @@ use std::{
|
||||
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
|
||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
use crate::block_allocator::BlockAllocation;
|
||||
|
||||
pub struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
@ -21,10 +21,18 @@ pub struct RadixAllocator {
|
||||
// This isn't used because the prefix need to match without the windowing
|
||||
// mechanism. This at worst is overallocating, not necessarily being wrong.
|
||||
window_size: Option<u32>,
|
||||
|
||||
/// Whether to actually use the radix tree for searching or not.
|
||||
prefix_caching: bool,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
pub fn new(
|
||||
block_size: u32,
|
||||
n_blocks: u32,
|
||||
window_size: Option<u32>,
|
||||
prefix_caching: bool,
|
||||
) -> Self {
|
||||
assert_eq!(
|
||||
block_size, 1,
|
||||
"Radix tree allocator only works with block_size=1, was: {}",
|
||||
@ -42,6 +50,7 @@ impl RadixAllocator {
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
window_size,
|
||||
prefix_caching,
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,23 +78,25 @@ impl RadixAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
// Allocator trait
|
||||
impl RadixAllocator {
|
||||
pub fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
// Even if this allocation fails below, we need to increase the
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) {
|
||||
(true, Some(prefill_tokens)) => {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
// Even if this allocation fails below, we need to increase the
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
node_id
|
||||
}
|
||||
_ => self.cache_blocks.root_id(),
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
@ -126,7 +137,7 @@ impl Allocator for RadixAllocator {
|
||||
})
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
pub fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
@ -574,13 +585,11 @@ where
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::block_allocator::Allocator;
|
||||
|
||||
use super::RadixAllocator;
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
let mut cache = RadixAllocator::new(1, 12, None, true);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, allocation.slots);
|
||||
@ -592,9 +601,23 @@ mod tests {
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
/// With prefix caching disabled, a second allocation for the same prefill
/// tokens must not report any reused prefix.
#[test]
fn allocator_doesnt_reuses_prefixes() {
    let mut cache = RadixAllocator::new(1, 12, None, false);
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
    // Previously this compared `allocation.slots` with itself, which can never fail.
    // NOTE(review): assumes slot ids equal block ids when block_size == 1 — confirm
    // against `RadixAllocator::allocate`.
    assert_eq!(allocation.blocks, allocation.slots);
    assert_eq!(allocation.prefix_len, 0);
    cache.free(allocation.blocks.clone(), allocation.allocation_id);

    // Same prefill tokens again: no prefix may be reported as cached.
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]);
    assert_eq!(allocation.prefix_len, 0);
}
|
||||
|
||||
#[test]
|
||||
fn allocator_collects_older_prefixes_first() {
|
||||
let mut cache = RadixAllocator::new(1, 7, None);
|
||||
let mut cache = RadixAllocator::new(1, 7, None, true);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
@ -614,7 +637,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_fully_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 10, None);
|
||||
let mut cache = RadixAllocator::new(1, 10, None, true);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
|
||||
@ -630,7 +653,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_partially_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 20, None);
|
||||
let mut cache = RadixAllocator::new(1, 20, None, true);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
@ -24,6 +24,44 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||
|
||||
mod env_runtime;
|
||||
|
||||
fn resolve_attention(config: &Config, lora_adapters: &Option<String>) -> (String, String) {
|
||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
match config.model_type.as_deref() {
|
||||
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
||||
// Required because gemma2 needs bfloat16 which is not supported by
|
||||
// flashinfer ?
|
||||
if prefix_caching.is_none() {
|
||||
tracing::info!(
|
||||
"Forcing flash decoding because model {} requires it",
|
||||
config.model_type.as_ref().unwrap()
|
||||
);
|
||||
prefix_caching = Some("0".to_string());
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if prefix_caching.is_none() {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
prefix_caching = Some("0".to_string());
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||
(prefix_caching, attention)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RawConfig {
|
||||
max_position_embeddings: Option<usize>,
|
||||
@ -1496,28 +1534,10 @@ fn main() -> Result<(), LauncherError> {
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
|
||||
let config: Config = config.into();
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if args.lora_adapters.is_some() {
|
||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||
}
|
||||
match config.model_type.as_deref() {
|
||||
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
||||
// Required because gemma2 needs bfloat16 which is not supported by
|
||||
// flashinfer ?
|
||||
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
}
|
||||
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
|
||||
std::env::set_var("ATTENTION", attention);
|
||||
let quantize = config.quantize;
|
||||
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
|
@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
|
||||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
CUDA_GRAPHS,
|
||||
PREFIX_CACHING,
|
||||
get_adapter_to_index,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
@ -266,7 +265,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
orig_input_length = len(tokenized_input)
|
||||
|
||||
if PREFIX_CACHING:
|
||||
if ATTENTION == "flashinfer":
|
||||
prefix_len = r.prefix_len
|
||||
if prefix_len == orig_input_length:
|
||||
assert prefix_len > 0
|
||||
@ -1452,7 +1451,7 @@ class FlashCausalLM(Model):
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
input_lengths = input_lengths + prefix_lens_tensor
|
||||
if PREFIX_CACHING:
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
|
@ -5,9 +5,9 @@ from typing import Dict, Optional
|
||||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "1").lower() in {"1", "true"}
|
||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
|
||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "flashdecoding")
|
||||
ATTENTION = os.getenv("ATTENTION")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
|
Loading…
Reference in New Issue
Block a user