diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 223ac67b..cc617c04 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -110,10 +110,15 @@ impl Allocator for RadixAllocator { let suffix_blocks = suffix_len.div_ceil(self.block_size); + let prefix_len_uncached = prefill_tokens.as_ref().map(|p| p.len()).unwrap_or_default(); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); metrics::counter!("tgi_cache_hit", "allocator" => "radix") .increment(prefix_len.try_into().expect("Can convert usize to u64")); - metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(suffix_len.into()); + metrics::counter!("tgi_cache_total", "allocator" => "radix").increment( + prefix_len_uncached + .try_into() + .expect("Can convert usize to u64"), + ); match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs index 31d1d81f..ba0bf3c8 100644 --- a/kvrouter/src/lib.rs +++ b/kvrouter/src/lib.rs @@ -18,37 +18,25 @@ const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR"; type Client = hyper_util::client::legacy::Client; #[derive(Clone)] -pub struct RoundRobin { - client: Client, +pub struct ContentAware { trie: Arc>, - backends: Arc>, - inqueue: Arc>, - inflight: Arc>, - factor: f32, } -impl RoundRobin { - pub fn new(backends: Vec) -> Self { - let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) - .build(HttpConnector::new()); - let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); - let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); - let trie = Arc::new(Mutex::new(Trie::new())); - let factor: f32 = std::env::var(FACTOR_KEY) - .unwrap_or("1.5".to_string()) - .parse() - .unwrap_or(1.5); - Self { - inflight, - inqueue, - trie, - client, - factor, - backends: Arc::new(backends), - } +impl Default for ContentAware { + fn default() -> Self { + Self::new() } +} - pub fn next(&mut self, key: &[u8]) -> usize { +impl ContentAware { + pub fn new() -> Self { + let trie = Arc::new(Mutex::new(Trie::new())); + Self { trie } + } +} + +impl LoadBalancer for ContentAware { + fn next(&mut self, key: &[u8], n_backends: usize) -> usize { let mut trie = self.trie.lock().unwrap(); let (start, stop) = trie.insert(key); let n = trie.count(); @@ -67,38 +55,108 @@ impl RoundRobin { assert!(rescaled_x >= 0.0); assert!(rescaled_x <= 1.0); println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}"); - let n: usize = (rescaled_x * (self.backends.len() as f32)) as usize; + let n: usize = (rescaled_x * (n_backends as f32)) as usize; n } } -pub async fn handler(State(mut state): State, req: Request) -> Response { +#[derive(Clone)] +pub struct RoundRobin { + current: Arc, +} + +impl Default for RoundRobin { + fn default() -> Self { + Self::new() + } +} + +impl RoundRobin { + pub fn new() -> Self { + let current = Arc::new(AtomicUsize::new(0)); + Self { current } + } +} + +impl LoadBalancer for RoundRobin { + fn next(&mut self, _key: &[u8], _n_backends: usize) -> usize { + self.current.fetch_add(1, Ordering::Relaxed) + } +} + +#[derive(Clone)] +pub struct OverloadHandler { + client: Client, + load_balancer: T, + backends: Arc>, + inqueue: Arc>, + inflight: Arc>, + factor: f32, +} + +impl OverloadHandler { + pub fn new(load_balancer: T, backends: Vec) -> Self { + let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) + .build(HttpConnector::new()); + let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect()); + let factor: f32 = std::env::var(FACTOR_KEY) + .unwrap_or("1.5".to_string()) + .parse() + .unwrap_or(1.5); + let backends = Arc::new(backends); + Self { + load_balancer, + backends, + client, + factor, + inflight, + inqueue, + } + } + + fn next(&mut self, key: &[u8]) -> usize { + // Get the backend URL + let index = self.load_balancer.next(key, self.backends.len()); + let n = self.backends.len(); + let mut index = index % n; + + let mut inflight = self.inflight[index].load(Ordering::Relaxed); + let mut inqueue = self.inqueue[index].load(Ordering::Relaxed); + + for i in 0..n { + if (inqueue as f32) <= self.factor * inflight as f32 { + break; + } + if i == 0 { + eprintln!( + "Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead" + ); + } + index += 1; + index %= self.backends.len(); + inflight = self.inflight[index].load(Ordering::Relaxed); + inqueue = self.inflight[index].load(Ordering::Relaxed); + } + index + } +} + +pub trait LoadBalancer { + fn next(&mut self, key: &[u8], n_backends: usize) -> usize; +} + +pub async fn handler( + State(mut state): State>, + req: Request, +) -> Response { // Get the next backend index - let limit = 2048usize; + let limit = 1024 * 1024; let (parts, body) = req.into_parts(); // TODO let bytes = axum::body::to_bytes(body, limit).await.unwrap(); let index = state.next(&bytes); - // Get the backend URL - let n = state.backends.len(); - let mut index = index % n; let backend = &state.backends[index]; - - let mut inflight = state.inflight[index].load(Ordering::Relaxed); - let mut inqueue = state.inqueue[index].load(Ordering::Relaxed); - - for i in 0..n { - if (inqueue as f32) <= state.factor * inflight as f32 { - break; - } - if i == 0 { - eprintln!("Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead"); - } - index += 1; - index %= state.backends.len(); - inflight = state.inflight[index].load(Ordering::Relaxed); - inqueue = state.inflight[index].load(Ordering::Relaxed); - } state.inflight[index].fetch_add(1, Ordering::Relaxed); state.inqueue[index].fetch_add(1, Ordering::Relaxed); @@ -131,6 +189,7 @@ pub async fn handler(State(mut state): State, req: Request) -> Respo while let Some(raw_event) = response_stream.next().await { if start{ eprintln!("Not inqueue"); + state.inqueue[index].fetch_sub(1, Ordering::Relaxed); start = false; } diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs index 7a213e45..58aaefca 100644 --- a/kvrouter/src/main.rs +++ b/kvrouter/src/main.rs @@ -2,7 +2,7 @@ use axum::{ routing::Router, routing::{get, post}, }; -use kvrouter::{handler, RoundRobin}; +use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin}; #[tokio::main] async fn main() { @@ -10,21 +10,42 @@ async fn main() { let backends = vec![ "http://localhost:8000".to_string(), "http://localhost:8001".to_string(), + "http://localhost:8002".to_string(), + "http://localhost:8003".to_string(), ]; // Create a new instance of the RoundRobinRouter - let router = RoundRobin::new(backends); + if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { + println!("Using round robin"); + let lb = RoundRobin::new(); + // Create the Axum router + let router = OverloadHandler::new(lb, backends); + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(router); - // Create the Axum router - let app = Router::new() - .route("/{*key}", get(handler)) - .route("/{*key}", post(handler)) - .with_state(router); + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); + } else { + println!("Using Content aware"); + let lb = ContentAware::new(); + // Create the Axum router + let router = OverloadHandler::new(lb, backends); + let app = Router::new() + .route("/{*key}", get(handler)) + .route("/{*key}", post(handler)) + .with_state(router); - // run it - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); + }; } diff --git a/load_tests/common.js b/load_tests/common.js index d890bf67..4e3cee4a 100644 --- a/load_tests/common.js +++ b/load_tests/common.js @@ -50,7 +50,7 @@ export function get_options() { throughput: { executor: 'shared-iterations', vus: 100, - iterations: 200, + iterations: 500, maxDuration: '40s', }, },