kvrouter: keep per-backend counters aligned with the backend list

* OverloadHandler::new contains no await point in the lines shown, so it
  becomes a plain fn and its callers drop the `.await`.
* The overload scan now reads `inqueue` instead of reading `inflight` twice.
* AddBackend / RemoveBackend / SetBackends update `backends`, `inflight`
  and `inqueue` together, so index i refers to the same backend in all three.
* RemoveBackend removes the matching backend; the previous `retain` call
  kept only the matching one.
* The listen address comes from TGI_KVROUTER_HOST / TGI_KVROUTER_PORT and
  defaults to 127.0.0.1:3000 (previously hard-coded to 0.0.0.0:3000).

NOTE(review): the indentation of the context lines below was reconstructed,
and generic parameters in context lines (`impl OverloadHandler`, `Vec`,
`Rcv`) look stripped. Verify both against the real files before applying.

diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs
index e6aa183d..e5bbbd13 100644
--- a/kvrouter/src/lib.rs
+++ b/kvrouter/src/lib.rs
@@ -95,7 +95,7 @@ pub struct OverloadHandler {
 }
 
 impl OverloadHandler {
-    pub async fn new(load_balancer: T, backends: Vec, rx: Rcv) -> Self {
+    pub fn new(load_balancer: T, backends: Vec, rx: Rcv) -> Self {
         let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
         let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
         let factor: f32 = std::env::var(FACTOR_KEY)
@@ -136,7 +136,7 @@ impl OverloadHandler {
             index += 1;
             index %= self.backends.len();
             inflight = self.inflight[index].load(Ordering::Relaxed);
-            inqueue = self.inflight[index].load(Ordering::Relaxed);
+            inqueue = self.inqueue[index].load(Ordering::Relaxed);
         }
         let backend = &self.backends[index];
         self.inflight[index].fetch_add(1, Ordering::Relaxed);
@@ -171,15 +171,47 @@ impl OverloadHandler {
                     }
                 }
                 Msg::AddBackend(backend) => {
-                    self.backends.push(backend);
-                    self.backends.sort();
+                    // Insert at the sorted position and give the new backend zeroed
+                    // counters at the same index. An already-known backend is a no-op.
+                    if let Err(pos) = self.backends.binary_search(&backend) {
+                        self.backends.insert(pos, backend);
+                        self.inflight.insert(pos, AtomicUsize::new(0));
+                        self.inqueue.insert(pos, AtomicUsize::new(0));
+                    }
                 }
                 Msg::RemoveBackend(backend) => {
-                    self.backends.retain(|b| *b == backend);
-                    self.backends.sort();
+                    let position = self.backends.iter().position(|b| *b == backend);
+                    if let Some(p) = position {
+                        self.backends.remove(p);
+                        self.inflight.remove(p);
+                        self.inqueue.remove(p);
+                    }
                 }
-                Msg::SetBackends(backends) => {
-                    self.backends = backends;
+                Msg::SetBackends(mut new_backends) => {
+                    new_backends.sort();
+                    // AddBackend never stores duplicates; keep the same rule here.
+                    new_backends.dedup();
+                    // Carry the counters of surviving backends over to their new
+                    // slots; backends that were not known before start at zero.
+                    // Atomics are not `Copy`, so the loaded values are re-wrapped.
+                    let mut new_inflight = Vec::with_capacity(new_backends.len());
+                    let mut new_inqueue = Vec::with_capacity(new_backends.len());
+                    for nb in &new_backends {
+                        let (inflight, inqueue) =
+                            match self.backends.iter().position(|b| b == nb) {
+                                Some(i) => (
+                                    self.inflight[i].load(Ordering::Relaxed),
+                                    self.inqueue[i].load(Ordering::Relaxed),
+                                ),
+                                None => (0, 0),
+                            };
+                        new_inflight.push(AtomicUsize::new(inflight));
+                        new_inqueue.push(AtomicUsize::new(inqueue));
+                    }
+
+                    self.backends = new_backends;
+                    self.inflight = new_inflight;
+                    self.inqueue = new_inqueue;
                 }
             }
         }
diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs
index 922dfa92..ec1d290f 100644
--- a/kvrouter/src/main.rs
+++ b/kvrouter/src/main.rs
@@ -24,15 +24,21 @@ async fn main() {
 
     let (sx, rx) = tokio::sync::mpsc::channel(100);
     let communicator = Communicator::new(sx);
+    // The bind address is configurable; it defaults to 127.0.0.1:3000.
+    let host = std::env::var("TGI_KVROUTER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
+    let port: u16 = std::env::var("TGI_KVROUTER_PORT")
+        .unwrap_or_else(|_| "3000".to_string())
+        .parse()
+        .expect("TGI_KVROUTER_PORT must be a valid port number (0-65535)");
     tokio::task::spawn(async move {
         if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
             println!("Using round robin");
             let lb = RoundRobin::new();
-            let mut router = OverloadHandler::new(lb, backends, rx).await;
+            let mut router = OverloadHandler::new(lb, backends, rx);
             router.run().await;
         } else {
             let lb = ContentAware::new();
-            let mut router = OverloadHandler::new(lb, backends, rx).await;
+            let mut router = OverloadHandler::new(lb, backends, rx);
             router.run().await;
         };
     });
@@ -44,7 +50,7 @@ async fn main() {
         .with_state(communicator);
 
     // run it
-    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
+    let listener = tokio::net::TcpListener::bind((host.as_str(), port)).await.unwrap();
     println!("listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
 }