Cleaner version.

Nicolas Patry 2025-01-31 09:07:31 +01:00
parent 1932c5b9ed
commit 57fa04adfd
4 changed files with 158 additions and 79 deletions

Cargo.lock (generated)

@@ -2254,6 +2254,7 @@ dependencies = [
  "futures-util",
  "hyper 1.5.2",
  "hyper-util",
+ "log",
  "rand 0.9.0",
  "slotmap",
  "tokio",
@@ -2352,9 +2353,9 @@ dependencies = [
 [[package]]
 name = "log"
-version = "0.4.22"
+version = "0.4.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
+checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"

 [[package]]
 name = "loop9"

Cargo.toml

@@ -12,6 +12,7 @@ futures = "0.3.31"
 futures-util = "0.3.31"
 hyper = { version = "1.5.2", features = ["full"] }
 hyper-util = { version = "0.1.10", features = ["full"] }
+log = "0.4.25"
 rand = "0.9.0"
 slotmap = "1.0.7"
 tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }

src/lib.rs

@@ -5,10 +5,11 @@ use axum::{
     response::{IntoResponse, Response},
 };
 use futures_util::stream::StreamExt;
+use hyper::StatusCode;
 use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
 use rand::{rng, Rng};
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::{Arc, Mutex};
+use tokio::sync::{mpsc, oneshot};

 mod trie;
@@ -17,9 +18,8 @@ use crate::trie::Trie;
 const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
 type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;

-#[derive(Clone)]
 pub struct ContentAware {
-    trie: Arc<Mutex<Trie>>,
+    trie: Trie,
 }

 impl Default for ContentAware {
@@ -30,14 +30,14 @@ impl Default for ContentAware {
 impl ContentAware {
     pub fn new() -> Self {
-        let trie = Arc::new(Mutex::new(Trie::new()));
+        let trie = Trie::new();
         Self { trie }
     }
 }

 impl LoadBalancer for ContentAware {
     fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
-        let mut trie = self.trie.lock().unwrap();
+        let trie = &mut self.trie;
         let (start, stop) = trie.insert(key);
         let n = trie.count();
         eprintln!(
@@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
     }
 }

-#[derive(Clone)]
 pub struct RoundRobin {
-    current: Arc<AtomicUsize>,
+    current: AtomicUsize,
 }

 impl Default for RoundRobin {
@@ -73,7 +72,7 @@ impl Default for RoundRobin {
 impl RoundRobin {
     pub fn new() -> Self {
-        let current = Arc::new(AtomicUsize::new(0));
+        let current = AtomicUsize::new(0);
         Self { current }
     }
 }
@@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
     }
 }

-#[derive(Clone)]
 pub struct OverloadHandler<T: LoadBalancer> {
-    client: Client,
     load_balancer: T,
-    backends: Arc<Vec<String>>,
-    inqueue: Arc<Vec<AtomicUsize>>,
-    inflight: Arc<Vec<AtomicUsize>>,
+    backends: Vec<String>,
+    inqueue: Vec<AtomicUsize>,
+    inflight: Vec<AtomicUsize>,
     factor: f32,
+    rx: Rcv,
 }

 impl<T: LoadBalancer> OverloadHandler<T> {
-    pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
-        let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
-            .build(HttpConnector::new());
-        let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
-        let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
+    pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
+        let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
+        let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
         let factor: f32 = std::env::var(FACTOR_KEY)
             .unwrap_or("1.5".to_string())
             .parse()
             .unwrap_or(1.5);
-        let backends = Arc::new(backends);
         Self {
             load_balancer,
             backends,
-            client,
             factor,
             inflight,
             inqueue,
+            rx,
         }
     }

-    fn next(&mut self, key: &[u8]) -> usize {
+    fn next(&mut self, key: &[u8]) -> String {
         // Get the backend URL
         let index = self.load_balancer.next(key, self.backends.len());
         let n = self.backends.len();
@@ -138,7 +133,45 @@ impl<T: LoadBalancer> OverloadHandler<T> {
             inflight = self.inflight[index].load(Ordering::Relaxed);
             inqueue = self.inflight[index].load(Ordering::Relaxed);
         }
-        index
+        let backend = &self.backends[index];
+        self.inflight[index].fetch_add(1, Ordering::Relaxed);
+        self.inqueue[index].fetch_add(1, Ordering::Relaxed);
+        backend.to_string()
+    }
+
+    pub async fn run(&mut self) {
+        while let Some(msg) = self.rx.recv().await {
+            eprintln!("Msg {msg:?}");
+            match msg {
+                Msg::Next(key, sx) => {
+                    let backend: String = self.next(&key);
+                    eprintln!("Sending back backend {backend}");
+                    if let Err(err) = sx.send(backend) {
+                        eprintln!("Cannot send back result: {err}");
+                    }
+                }
+                Msg::Dequeue(backend) => {
+                    let index = self.backends.iter().position(|b| b == &backend);
+                    if let Some(index) = index {
+                        self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
+                    }
+                }
+                Msg::Deflight(backend) => {
+                    let index = self.backends.iter().position(|b| b == &backend);
+                    if let Some(index) = index {
+                        self.inflight[index].fetch_sub(1, Ordering::Relaxed);
+                    }
+                }
+                Msg::AddBackend(backend) => {
+                    self.backends.push(backend);
+                    self.backends.sort();
+                }
+                Msg::RemoveBackend(backend) => {
+                    self.backends.retain(|b| *b == backend);
+                    self.backends.sort();
+                }
+            }
+        }
     }
 }
@@ -146,21 +179,71 @@ pub trait LoadBalancer {
     fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
 }

-pub async fn handler<T: LoadBalancer>(
-    State(mut state): State<OverloadHandler<T>>,
-    req: Request,
-) -> Response<Body> {
-    // Get the next backend index
-    let limit = 1024 * 1024;
-    let (parts, body) = req.into_parts();
-    // TODO
-    let bytes = axum::body::to_bytes(body, limit).await.unwrap();
-    let index = state.next(&bytes);
-    let backend = &state.backends[index];
-    state.inflight[index].fetch_add(1, Ordering::Relaxed);
-    state.inqueue[index].fetch_add(1, Ordering::Relaxed);
-    let body: Body = bytes.into();
+#[derive(Debug)]
+pub enum Msg {
+    Next(Vec<u8>, oneshot::Sender<String>),
+    Dequeue(String),
+    Deflight(String),
+    AddBackend(String),
+    RemoveBackend(String),
+}
+
+type Snd = mpsc::Sender<Msg>;
+type Rcv = mpsc::Receiver<Msg>;
+
+#[derive(Clone)]
+pub struct Communicator {
+    sender: Snd,
+    client: Client,
+}
+
+impl Communicator {
+    pub fn new(sender: Snd) -> Self {
+        let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
+            .build(HttpConnector::new());
+        Self { sender, client }
+    }
+
+    async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
+        self.sender.send(Msg::Dequeue(backend)).await
+    }
+
+    async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
+        self.sender.send(Msg::Deflight(backend)).await
+    }
+
+    async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
+        let (sx, rx) = oneshot::channel();
+        self.sender.send(Msg::Next(key, sx)).await?;
+        let backend = rx.await.unwrap();
+        Ok(backend)
+    }
+}
+
+pub async fn handler(
+    State(state): State<Communicator>,
+    req: Request,
+) -> Result<Response<Body>, StatusCode> {
+    // Get the next backend index
+    let (parts, body) = req.into_parts();
+    let mut response_stream = body.into_data_stream();
+    let event = response_stream.next().await;
+    let key = if let Some(Ok(event)) = &event {
+        event.to_vec()
+    } else {
+        vec![]
+    };
+    let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?;
+    let response_stream = async_stream::stream! {
+        let mut response_stream = Box::pin(response_stream);
+        if let Some(event) = event {
+            yield event;
+        }
+        while let Some(raw_event) = response_stream.next().await {
+            yield raw_event;
+        }
+    };
+    let body = Body::from_stream(response_stream);
     let mut req = Request::from_parts(parts, body);
     let path = req.uri().path();
     let path_query = req
@@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
         .client
         .request(req)
         .await
-        // TODO
-        .unwrap();
-        //.map_err(|_| StatusCode::BAD_GATEWAY)?;
+        .map_err(|_| StatusCode::BAD_GATEWAY)?;
     let response = response.into_response();
     let (parts, body) = response.into_parts();
     let response_stream = body.into_data_stream();
@@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
             if start {
                 eprintln!("Not inqueue");
-                state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
+                state.dequeue(backend.to_string()).await.unwrap();
                 start = false;
             }
             yield raw_event;
         }
         eprintln!("Not inflight");
-        state.inflight[index].fetch_sub(1, Ordering::Relaxed);
+        state.deflight(backend.to_string()).await.unwrap();
     };
     let body = Body::from_stream(response_stream);
-    Response::from_parts(parts, body)
+    Ok(Response::from_parts(parts, body))
 }
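
The heart of this change is that the routing state (the trie and the per-backend inflight/inqueue counters) is now owned by a single OverloadHandler task, and the axum handler reaches it through channels instead of sharing it behind Arc<Mutex<...>>; that is also why the #[derive(Clone)] attributes and the Arc wrappers disappear. A minimal, self-contained sketch of the mpsc + oneshot round trip this relies on (the message type, backend names, and key-to-backend rule below are illustrative, not the crate's own):

// Sketch only: one task owns the routing state and answers requests over channels.
use tokio::sync::{mpsc, oneshot};

// Illustrative message, mirroring the shape of Msg::Next in lib.rs.
#[derive(Debug)]
enum Msg {
    Next(Vec<u8>, oneshot::Sender<String>),
}

#[tokio::main]
async fn main() {
    let (sx, mut rx) = mpsc::channel::<Msg>(100);

    // The "OverloadHandler::run" side: the spawned task is the only owner of
    // the state, so no locking is needed.
    tokio::task::spawn(async move {
        while let Some(Msg::Next(key, reply)) = rx.recv().await {
            // Made-up key-to-backend rule, standing in for the real load balancer.
            let backend = format!("http://localhost:800{}", key.len() % 4);
            let _ = reply.send(backend);
        }
    });

    // The "Communicator" side: ship the key plus a oneshot sender,
    // then wait for the chosen backend to come back.
    let (reply_sx, reply_rx) = oneshot::channel();
    sx.send(Msg::Next(b"prompt".to_vec(), reply_sx)).await.unwrap();
    let backend = reply_rx.await.unwrap();
    println!("routed to {backend}");
}

main.rs below wires this up for real: it creates the channel, hands the receiver to OverloadHandler::run on a spawned task, and gives the cloneable Communicator to axum as state.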

src/main.rs

@@ -2,44 +2,41 @@ use axum::{
     routing::Router,
     routing::{get, post},
 };
-use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
+use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};

 #[tokio::main]
 async fn main() {
     // List of backend servers
     let backends = vec![
         "http://localhost:8000".to_string(),
-        "http://localhost:8001".to_string(),
-        "http://localhost:8002".to_string(),
-        "http://localhost:8003".to_string(),
+        // "http://localhost:8001".to_string(),
+        // "http://localhost:8002".to_string(),
+        // "http://localhost:8003".to_string(),
     ];

     // Create a new instance of the RoundRobinRouter
+    println!("Using Content aware");

-    if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
-        println!("Using round robin");
-        let lb = RoundRobin::new();
-        // Create the Axum router
-        let router = OverloadHandler::new(lb, backends);
-        let app = Router::new()
-            .route("/{*key}", get(handler))
-            .route("/{*key}", post(handler))
-            .with_state(router);
-
-        // run it
-        let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
-            .await
-            .unwrap();
-        println!("listening on {}", listener.local_addr().unwrap());
-        axum::serve(listener, app).await.unwrap();
-    } else {
-        println!("Using Content aware");
-        let lb = ContentAware::new();
-        // Create the Axum router
-        let router = OverloadHandler::new(lb, backends);
-        let app = Router::new()
-            .route("/{*key}", get(handler))
-            .route("/{*key}", post(handler))
-            .with_state(router);
+    // Create the Axum router
+    let (sx, rx) = tokio::sync::mpsc::channel(100);
+    let communicator = Communicator::new(sx);
+    tokio::task::spawn(async move {
+        if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
+            println!("Using round robin");
+            let lb = RoundRobin::new();
+            let mut router = OverloadHandler::new(lb, backends, rx);
+            router.run().await;
+        } else {
+            let lb = ContentAware::new();
+            let mut router = OverloadHandler::new(lb, backends, rx);
+            router.run().await;
+        };
+    });
+    let app = Router::new()
+        .route("/{*key}", get(handler))
+        .route("/{*key}", post(handler))
+        .with_state(communicator);

     // run it
     let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
@@ -47,5 +44,4 @@ async fn main() {
         .unwrap();
     println!("listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
-    };
 }
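
The other technique worth noting in the new handler is peeking the first body chunk to build the routing key, then splicing that chunk back onto the stream that is forwarded to the backend, so the proxied request body stays byte-identical to what the client sent. A self-contained sketch of that peek-and-replay pattern, using the same futures-util and async_stream crates the handler already uses (the chunk contents and error type here are made up):

// Sketch of "peek then replay": take the first chunk to derive a routing key,
// then rebuild a stream that still yields that chunk first.
use async_stream::stream;
use futures_util::stream::StreamExt;

#[tokio::main]
async fn main() {
    // Stand-in for body.into_data_stream(); chunks and error type are illustrative.
    let mut body = futures_util::stream::iter(vec![
        Ok::<Vec<u8>, std::io::Error>(b"first chunk".to_vec()),
        Ok(b"rest of the body".to_vec()),
    ]);

    // Peek the first chunk: this is what becomes the load-balancing key.
    let first = body.next().await;
    let key: Vec<u8> = if let Some(Ok(chunk)) = &first {
        chunk.clone()
    } else {
        vec![]
    };
    println!("routing key is {} bytes", key.len());

    // Replay the peeked chunk in front of the remaining stream.
    let replayed = stream! {
        let mut body = Box::pin(body);
        if let Some(chunk) = first {
            yield chunk;
        }
        while let Some(chunk) = body.next().await {
            yield chunk;
        }
    };
    let chunks: Vec<_> = replayed.collect().await;
    assert_eq!(chunks.len(), 2);
}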