Cleaner version.

Nicolas Patry 2025-01-31 09:07:31 +01:00
parent 1932c5b9ed
commit 57fa04adfd
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
4 changed files with 158 additions and 79 deletions

Cargo.lock (generated)

@@ -2254,6 +2254,7 @@ dependencies = [
"futures-util",
"hyper 1.5.2",
"hyper-util",
"log",
"rand 0.9.0",
"slotmap",
"tokio",
@@ -2352,9 +2353,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.22"
version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]]
name = "loop9"


@@ -12,6 +12,7 @@ futures = "0.3.31"
futures-util = "0.3.31"
hyper = { version = "1.5.2", features = ["full"] }
hyper-util = { version = "0.1.10", features = ["full"] }
log = "0.4.25"
rand = "0.9.0"
slotmap = "1.0.7"
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }


@@ -5,10 +5,11 @@ use axum::{
response::{IntoResponse, Response},
};
use futures_util::stream::StreamExt;
use hyper::StatusCode;
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
use rand::{rng, Rng};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::{mpsc, oneshot};
mod trie;
@@ -17,9 +18,8 @@ use crate::trie::Trie;
const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
#[derive(Clone)]
pub struct ContentAware {
trie: Arc<Mutex<Trie>>,
trie: Trie,
}
impl Default for ContentAware {
@@ -30,14 +30,14 @@ impl Default for ContentAware {
impl ContentAware {
pub fn new() -> Self {
let trie = Arc::new(Mutex::new(Trie::new()));
let trie = Trie::new();
Self { trie }
}
}
impl LoadBalancer for ContentAware {
fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
let mut trie = self.trie.lock().unwrap();
let trie = &mut self.trie;
let (start, stop) = trie.insert(key);
let n = trie.count();
eprintln!(
@@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
}
}
#[derive(Clone)]
pub struct RoundRobin {
current: Arc<AtomicUsize>,
current: AtomicUsize,
}
impl Default for RoundRobin {
@@ -73,7 +72,7 @@ impl Default for RoundRobin {
impl RoundRobin {
pub fn new() -> Self {
let current = Arc::new(AtomicUsize::new(0));
let current = AtomicUsize::new(0);
Self { current }
}
}
@@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
}
}
#[derive(Clone)]
pub struct OverloadHandler<T: LoadBalancer> {
client: Client,
load_balancer: T,
backends: Arc<Vec<String>>,
inqueue: Arc<Vec<AtomicUsize>>,
inflight: Arc<Vec<AtomicUsize>>,
backends: Vec<String>,
inqueue: Vec<AtomicUsize>,
inflight: Vec<AtomicUsize>,
factor: f32,
rx: Rcv,
}
impl<T: LoadBalancer> OverloadHandler<T> {
pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
let factor: f32 = std::env::var(FACTOR_KEY)
.unwrap_or("1.5".to_string())
.parse()
.unwrap_or(1.5);
let backends = Arc::new(backends);
Self {
load_balancer,
backends,
client,
factor,
inflight,
inqueue,
rx,
}
}
fn next(&mut self, key: &[u8]) -> usize {
fn next(&mut self, key: &[u8]) -> String {
// Get the backend URL
let index = self.load_balancer.next(key, self.backends.len());
let n = self.backends.len();
@@ -138,7 +133,45 @@ impl<T: LoadBalancer> OverloadHandler<T> {
inflight = self.inflight[index].load(Ordering::Relaxed);
inqueue = self.inflight[index].load(Ordering::Relaxed);
}
index
let backend = &self.backends[index];
self.inflight[index].fetch_add(1, Ordering::Relaxed);
self.inqueue[index].fetch_add(1, Ordering::Relaxed);
backend.to_string()
}
pub async fn run(&mut self) {
while let Some(msg) = self.rx.recv().await {
eprintln!("Msg {msg:?}");
match msg {
Msg::Next(key, sx) => {
let backend: String = self.next(&key);
eprintln!("Sending back backend {backend}");
if let Err(err) = sx.send(backend) {
eprintln!("Cannot send back result: {err}");
}
}
Msg::Dequeue(backend) => {
let index = self.backends.iter().position(|b| b == &backend);
if let Some(index) = index {
self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
}
}
Msg::Deflight(backend) => {
let index = self.backends.iter().position(|b| b == &backend);
if let Some(index) = index {
self.inflight[index].fetch_sub(1, Ordering::Relaxed);
}
}
Msg::AddBackend(backend) => {
self.backends.push(backend);
self.backends.sort();
}
Msg::RemoveBackend(backend) => {
self.backends.retain(|b| *b == backend);
self.backends.sort();
}
}
}
}
}
@@ -146,21 +179,71 @@ pub trait LoadBalancer {
fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
}
pub async fn handler<T: LoadBalancer>(
State(mut state): State<OverloadHandler<T>>,
req: Request,
) -> Response<Body> {
// Get the next backend index
let limit = 1024 * 1024;
let (parts, body) = req.into_parts();
// TODO
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
let index = state.next(&bytes);
let backend = &state.backends[index];
state.inflight[index].fetch_add(1, Ordering::Relaxed);
state.inqueue[index].fetch_add(1, Ordering::Relaxed);
#[derive(Debug)]
pub enum Msg {
Next(Vec<u8>, oneshot::Sender<String>),
Dequeue(String),
Deflight(String),
AddBackend(String),
RemoveBackend(String),
}
let body: Body = bytes.into();
type Snd = mpsc::Sender<Msg>;
type Rcv = mpsc::Receiver<Msg>;
#[derive(Clone)]
pub struct Communicator {
sender: Snd,
client: Client,
}
impl Communicator {
pub fn new(sender: Snd) -> Self {
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
Self { sender, client }
}
async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
self.sender.send(Msg::Dequeue(backend)).await
}
async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
self.sender.send(Msg::Deflight(backend)).await
}
async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
let (sx, rx) = oneshot::channel();
self.sender.send(Msg::Next(key, sx)).await?;
let backend = rx.await.unwrap();
Ok(backend)
}
}
pub async fn handler(
State(state): State<Communicator>,
req: Request,
) -> Result<Response<Body>, StatusCode> {
// Get the next backend index
let (parts, body) = req.into_parts();
let mut response_stream = body.into_data_stream();
let event = response_stream.next().await;
let key = if let Some(Ok(event)) = &event {
event.to_vec()
} else {
vec![]
};
let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
if let Some(event) = event{
yield event;
}
while let Some(raw_event) = response_stream.next().await {
yield raw_event;
}
};
let body = Body::from_stream(response_stream);
let mut req = Request::from_parts(parts, body);
let path = req.uri().path();
let path_query = req
@@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
.client
.request(req)
.await
// TODO
.unwrap();
//.map_err(|_| StatusCode::BAD_GATEWAY)?;
.map_err(|_| StatusCode::BAD_GATEWAY)?;
let response = response.into_response();
let (parts, body) = response.into_parts();
let response_stream = body.into_data_stream();
@@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
if start{
eprintln!("Not inqueue");
state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
state.dequeue(backend.to_string()).await.unwrap();
start = false;
}
yield raw_event;
}
eprintln!("Not inflight");
state.inflight[index].fetch_sub(1, Ordering::Relaxed);
state.deflight(backend.to_string()).await.unwrap();
};
let body = Body::from_stream(response_stream);
Response::from_parts(parts, body)
Ok(Response::from_parts(parts, body))
}
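
Illustration, not part of the commit: the new Msg enum, Communicator, and OverloadHandler::run replace the shared Arc/Mutex state with a single task that owns the routing state and is driven over an mpsc channel. A minimal in-crate test sketch of that round trip (the module and test names, the single localhost backend, and the sample key are hypothetical; it assumes the module sits next to the code above so it can call the private Communicator::next):

#[cfg(test)]
mod actor_round_trip {
    use super::*;

    #[tokio::test]
    async fn next_returns_the_only_backend() {
        // The Communicator is the cheap-to-clone axum state; it only holds the
        // mpsc sender (plus the hyper client used later by the handler).
        let (sx, rx) = tokio::sync::mpsc::channel(8);
        let communicator = Communicator::new(sx);

        // A single spawned task owns the backends, the in-flight / in-queue
        // counters and (for ContentAware) the trie, so no Arc<Mutex<...>> is needed.
        tokio::task::spawn(async move {
            let lb = RoundRobin::new();
            let mut handler =
                OverloadHandler::new(lb, vec!["http://localhost:8000".to_string()], rx);
            handler.run().await;
        });

        // Msg::Next carries a oneshot::Sender; the chosen backend URL comes back on it.
        let backend = communicator.next(b"some prompt".to_vec()).await.unwrap();
        assert_eq!(backend, "http://localhost:8000");
    }
}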


@@ -2,50 +2,46 @@ use axum::{
routing::Router,
routing::{get, post},
};
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};
#[tokio::main]
async fn main() {
// List of backend servers
let backends = vec![
"http://localhost:8000".to_string(),
"http://localhost:8001".to_string(),
"http://localhost:8002".to_string(),
"http://localhost:8003".to_string(),
// "http://localhost:8001".to_string(),
// "http://localhost:8002".to_string(),
// "http://localhost:8003".to_string(),
];
// Create a new instance of the RoundRobinRouter
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
println!("Using round robin");
let lb = RoundRobin::new();
// Create the Axum router
let router = OverloadHandler::new(lb, backends);
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(router);
// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
} else {
println!("Using Content aware");
let lb = ContentAware::new();
// Create the Axum router
let router = OverloadHandler::new(lb, backends);
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(router);
println!("Using Content aware");
// Create the Axum router
// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
};
let (sx, rx) = tokio::sync::mpsc::channel(100);
let communicator = Communicator::new(sx);
tokio::task::spawn(async move {
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
println!("Using round robin");
let lb = RoundRobin::new();
let mut router = OverloadHandler::new(lb, backends, rx);
router.run().await;
} else {
let lb = ContentAware::new();
let mut router = OverloadHandler::new(lb, backends, rx);
router.run().await;
};
});
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.with_state(communicator);
// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
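
Illustration, not part of the commit: routing policies stay behind the LoadBalancer trait and are driven from the single OverloadHandler task, so a custom policy needs no Arc or Mutex of its own. A hypothetical key-hashing policy sketched against the trait above (KeyHash is not in the codebase):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use kvrouter::LoadBalancer;

// Hypothetical policy: identical request keys always map to the same backend index.
pub struct KeyHash;

impl LoadBalancer for KeyHash {
    fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        // max(1) guards the modulo in case the backend list were ever empty.
        (hasher.finish() as usize) % n_backends.max(1)
    }
}

It would slot in wherever RoundRobin or ContentAware is constructed above, e.g. OverloadHandler::new(KeyHash, backends, rx).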