Updating the kvrouter to support roundrobin

for comparison, still withthe overloading checks
This commit is contained in:
Nicolas Patry 2025-01-29 12:40:26 +01:00
parent 6a88063cc2
commit 0a495ad118
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
4 changed files with 148 additions and 63 deletions

View File

@ -110,10 +110,15 @@ impl Allocator for RadixAllocator {
let suffix_blocks = suffix_len.div_ceil(self.block_size);
let prefix_len_uncached = prefill_tokens.as_ref().map(|p| p.len()).unwrap_or_default();
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
metrics::counter!("tgi_cache_hit", "allocator" => "radix")
.increment(prefix_len.try_into().expect("Can convert usize to u64"));
metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(suffix_len.into());
metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(
prefix_len_uncached
.try_into()
.expect("Can convert usize to u64"),
);
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),

View File

@ -18,37 +18,25 @@ const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
#[derive(Clone)]
pub struct RoundRobin {
client: Client,
pub struct ContentAware {
trie: Arc<Mutex<Trie>>,
backends: Arc<Vec<String>>,
inqueue: Arc<Vec<AtomicUsize>>,
inflight: Arc<Vec<AtomicUsize>>,
factor: f32,
}
impl RoundRobin {
pub fn new(backends: Vec<String>) -> Self {
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
let trie = Arc::new(Mutex::new(Trie::new()));
let factor: f32 = std::env::var(FACTOR_KEY)
.unwrap_or("1.5".to_string())
.parse()
.unwrap_or(1.5);
Self {
inflight,
inqueue,
trie,
client,
factor,
backends: Arc::new(backends),
}
impl Default for ContentAware {
fn default() -> Self {
Self::new()
}
}
pub fn next(&mut self, key: &[u8]) -> usize {
impl ContentAware {
pub fn new() -> Self {
let trie = Arc::new(Mutex::new(Trie::new()));
Self { trie }
}
}
impl LoadBalancer for ContentAware {
fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
let mut trie = self.trie.lock().unwrap();
let (start, stop) = trie.insert(key);
let n = trie.count();
@ -67,38 +55,108 @@ impl RoundRobin {
assert!(rescaled_x >= 0.0);
assert!(rescaled_x <= 1.0);
println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}");
let n: usize = (rescaled_x * (self.backends.len() as f32)) as usize;
let n: usize = (rescaled_x * (n_backends as f32)) as usize;
n
}
}
pub async fn handler(State(mut state): State<RoundRobin>, req: Request) -> Response<Body> {
#[derive(Clone)]
pub struct RoundRobin {
current: Arc<AtomicUsize>,
}
impl Default for RoundRobin {
fn default() -> Self {
Self::new()
}
}
impl RoundRobin {
pub fn new() -> Self {
let current = Arc::new(AtomicUsize::new(0));
Self { current }
}
}
impl LoadBalancer for RoundRobin {
fn next(&mut self, _key: &[u8], _n_backends: usize) -> usize {
self.current.fetch_add(1, Ordering::Relaxed)
}
}
#[derive(Clone)]
pub struct OverloadHandler<T: LoadBalancer> {
client: Client,
load_balancer: T,
backends: Arc<Vec<String>>,
inqueue: Arc<Vec<AtomicUsize>>,
inflight: Arc<Vec<AtomicUsize>>,
factor: f32,
}
impl<T: LoadBalancer> OverloadHandler<T> {
pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
let factor: f32 = std::env::var(FACTOR_KEY)
.unwrap_or("1.5".to_string())
.parse()
.unwrap_or(1.5);
let backends = Arc::new(backends);
Self {
load_balancer,
backends,
client,
factor,
inflight,
inqueue,
}
}
fn next(&mut self, key: &[u8]) -> usize {
// Get the backend URL
let index = self.load_balancer.next(key, self.backends.len());
let n = self.backends.len();
let mut index = index % n;
let mut inflight = self.inflight[index].load(Ordering::Relaxed);
let mut inqueue = self.inqueue[index].load(Ordering::Relaxed);
for i in 0..n {
if (inqueue as f32) <= self.factor * inflight as f32 {
break;
}
if i == 0 {
eprintln!(
"Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead"
);
}
index += 1;
index %= self.backends.len();
inflight = self.inflight[index].load(Ordering::Relaxed);
inqueue = self.inflight[index].load(Ordering::Relaxed);
}
index
}
}
pub trait LoadBalancer {
fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
}
pub async fn handler<T: LoadBalancer>(
State(mut state): State<OverloadHandler<T>>,
req: Request,
) -> Response<Body> {
// Get the next backend index
let limit = 2048usize;
let limit = 1024 * 1024;
let (parts, body) = req.into_parts();
// TODO
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
let index = state.next(&bytes);
// Get the backend URL
let n = state.backends.len();
let mut index = index % n;
let backend = &state.backends[index];
let mut inflight = state.inflight[index].load(Ordering::Relaxed);
let mut inqueue = state.inqueue[index].load(Ordering::Relaxed);
for i in 0..n {
if (inqueue as f32) <= state.factor * inflight as f32 {
break;
}
if i == 0 {
eprintln!("Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead");
}
index += 1;
index %= state.backends.len();
inflight = state.inflight[index].load(Ordering::Relaxed);
inqueue = state.inflight[index].load(Ordering::Relaxed);
}
state.inflight[index].fetch_add(1, Ordering::Relaxed);
state.inqueue[index].fetch_add(1, Ordering::Relaxed);
@ -131,6 +189,7 @@ pub async fn handler(State(mut state): State<RoundRobin>, req: Request) -> Respo
while let Some(raw_event) = response_stream.next().await {
if start{
eprintln!("Not inqueue");
state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
start = false;
}

View File

@ -2,7 +2,7 @@ use axum::{
routing::Router,
routing::{get, post},
};
use kvrouter::{handler, RoundRobin};
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
#[tokio::main]
async fn main() {
@ -10,12 +10,16 @@ async fn main() {
let backends = vec![
"http://localhost:8000".to_string(),
"http://localhost:8001".to_string(),
"http://localhost:8002".to_string(),
"http://localhost:8003".to_string(),
];
// Create a new instance of the RoundRobinRouter
let router = RoundRobin::new(backends);
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
println!("Using round robin");
let lb = RoundRobin::new();
// Create the Axum router
let router = OverloadHandler::new(lb, backends);
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
@ -27,4 +31,21 @@ async fn main() {
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
} else {
println!("Using Content aware");
let lb = ContentAware::new();
// Create the Axum router
let router = OverloadHandler::new(lb, backends);
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(router);
// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
};
}

View File

@ -50,7 +50,7 @@ export function get_options() {
throughput: {
executor: 'shared-iterations',
vus: 100,
iterations: 200,
iterations: 500,
maxDuration: '40s',
},
},