diff --git a/Cargo.lock b/Cargo.lock index 9b741b1f..80d5f555 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2256,6 +2256,7 @@ dependencies = [ "hyper-util", "log", "rand 0.9.0", + "serde", "slotmap", "tokio", ] diff --git a/kvrouter/Cargo.toml b/kvrouter/Cargo.toml index 13944793..ed4ea666 100644 --- a/kvrouter/Cargo.toml +++ b/kvrouter/Cargo.toml @@ -14,6 +14,6 @@ hyper = { version = "1.5.2", features = ["full"] } hyper-util = { version = "0.1.10", features = ["full"] } log = "0.4.25" rand = "0.9.0" -serde = "1" +serde = { version = "1", features = ["derive"] } slotmap = "1.0.7" tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] } diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs index e2830abc..e6aa183d 100644 --- a/kvrouter/src/lib.rs +++ b/kvrouter/src/lib.rs @@ -10,11 +10,8 @@ use hyper::StatusCode; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use rand::{rng, Rng}; use serde::Deserialize; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; -use tokio::sync::{mpsc, oneshot, RwLock}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tokio::sync::{mpsc, oneshot}; mod trie; @@ -90,7 +87,7 @@ impl LoadBalancer for RoundRobin { pub struct OverloadHandler { load_balancer: T, - backends: Arc>>, + backends: Vec, inqueue: Vec, inflight: Vec, factor: f32, @@ -98,19 +95,9 @@ pub struct OverloadHandler { } impl OverloadHandler { - pub async fn new(load_balancer: T, backends: Arc>>, rx: Rcv) -> Self { - let inflight = backends - .read() - .await - .iter() - .map(|_| AtomicUsize::new(0)) - .collect(); - let inqueue = backends - .read() - .await - .iter() - .map(|_| AtomicUsize::new(0)) - .collect(); + pub async fn new(load_balancer: T, backends: Vec, rx: Rcv) -> Self { + let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect(); + let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect(); let factor: f32 = std::env::var(FACTOR_KEY) .unwrap_or("1.5".to_string()) .parse() @@ -126,13 +113,12 @@ impl OverloadHandler { } async fn next(&mut self, key: &[u8]) -> Option { - let backends = self.backends.read().await; - if backends.is_empty() { + if self.backends.is_empty() { return None; } // Get the backend URL - let index = self.load_balancer.next(key, backends.len()); - let n = backends.len(); + let index = self.load_balancer.next(key, self.backends.len()); + let n = self.backends.len(); let mut index = index % n; let mut inflight = self.inflight[index].load(Ordering::Relaxed); @@ -148,11 +134,11 @@ impl OverloadHandler { ); } index += 1; - index %= backends.len(); + index %= self.backends.len(); inflight = self.inflight[index].load(Ordering::Relaxed); inqueue = self.inflight[index].load(Ordering::Relaxed); } - let backend = &backends[index]; + let backend = &self.backends[index]; self.inflight[index].fetch_add(1, Ordering::Relaxed); self.inqueue[index].fetch_add(1, Ordering::Relaxed); Some(backend.to_string()) @@ -173,39 +159,27 @@ impl OverloadHandler { } } Msg::Dequeue(backend) => { - let index = self - .backends - .read() - .await - .iter() - .position(|b| b == &backend); + let index = self.backends.iter().position(|b| b == &backend); if let Some(index) = index { self.inqueue[index].fetch_sub(1, Ordering::Relaxed); } } Msg::Deflight(backend) => { - let index = self - .backends - .read() - .await - .iter() - .position(|b| b == &backend); + let index = self.backends.iter().position(|b| b == &backend); if let Some(index) = index { self.inflight[index].fetch_sub(1, Ordering::Relaxed); } } Msg::AddBackend(backend) => { - let mut backends = self.backends.write().await; - backends.push(backend); - backends.sort(); + self.backends.push(backend); + self.backends.sort(); } Msg::RemoveBackend(backend) => { - let mut backends = self.backends.write().await; - backends.retain(|b| *b == backend); - backends.sort(); + self.backends.retain(|b| *b == backend); + self.backends.sort(); } Msg::SetBackends(backends) => { - *self.backends.write().await = backends; + self.backends = backends; } } } diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs index 6ff1e0b5..922dfa92 100644 --- a/kvrouter/src/main.rs +++ b/kvrouter/src/main.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use axum::{ routing::Router, routing::{get, post}, @@ -8,17 +6,16 @@ use hyper::StatusCode; use kvrouter::{ handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin, }; -use tokio::sync::RwLock; #[tokio::main] async fn main() { // List of backend servers - let backends = Arc::new(RwLock::new(vec![ + let backends = vec![ // "http://localhost:8000".to_string(), // "http://localhost:8001".to_string(), // "http://localhost:8002".to_string(), // "http://localhost:8003".to_string(), - ])); + ]; // Create a new instance of the RoundRobinRouter