diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs index e5bbbd13d..33b436390 100644 --- a/kvrouter/src/lib.rs +++ b/kvrouter/src/lib.rs @@ -172,17 +172,19 @@ impl OverloadHandler { } Msg::AddBackend(backend) => { match self.backends.binary_search(&backend) { - Ok(pos) => {} // element already in vector @ `pos` + Ok(pos) => { + log::warn!("Backend {backend} already exists at pos {pos}"); + } // element already in vector @ `pos` Err(pos) => { - self.backends.insert(pos, new_elem) - self.inflight.insert(pos, AtomicUsize::new(0)) - self.inqueue.insert(pos, AtomicUsize::new(0)) + self.backends.insert(pos, backend); + self.inflight.insert(pos, AtomicUsize::new(0)); + self.inqueue.insert(pos, AtomicUsize::new(0)); } } } Msg::RemoveBackend(backend) => { let position = self.backends.iter().position(|b| *b == backend); - if let Some(p) = position{ + if let Some(p) = position { self.backends.remove(p); self.inflight.remove(p); self.inqueue.remove(p); @@ -190,19 +192,27 @@ impl OverloadHandler { } Msg::SetBackends(mut new_backends) => { new_backends.sort(); - let (new_backends, new_inflight, new_inqueue): (Vec<_>, Vec<_>, Vec<_>) = self.new_backends.iter().enumerate().map(|(ni, nb)| { - if let Some(i) = backends.iter().position(|b| b == nb){ - let inflight = self.inflight[i]; - let inqueue = self.inqueue[i]; - (nb, inflight, inqueue) - }else{ - (nb, AtomicUsize::new(0), AtomicUsize::new(0)) - } - }).collect(); + let (new_backends, (new_inflight, new_inqueue)): ( + Vec, + (Vec<_>, Vec<_>), + ) = new_backends + .into_iter() + .map(|nb| { + if let Some(i) = self.backends.iter().position(|b| *b == nb) { + let inflight = + AtomicUsize::new(self.inflight[i].load(Ordering::Relaxed)); + let inqueue = + AtomicUsize::new(self.inqueue[i].load(Ordering::Relaxed)); + (nb, (inflight, inqueue)) + } else { + (nb, (AtomicUsize::new(0), AtomicUsize::new(0))) + } + }) + .unzip(); self.backends = new_backends; - self.inflight = inflight; - self.inqueue = inqueue; + self.inflight = new_inflight; + self.inqueue = new_inqueue; } } } diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs index ec1d290fa..12ccb3c4c 100644 --- a/kvrouter/src/main.rs +++ b/kvrouter/src/main.rs @@ -8,7 +8,7 @@ use kvrouter::{ }; #[tokio::main] -async fn main() { +async fn main() -> Result<(), Box> { // List of backend servers let backends = vec![ // "http://localhost:8000".to_string(), @@ -24,8 +24,10 @@ async fn main() { let (sx, rx) = tokio::sync::mpsc::channel(100); let communicator = Communicator::new(sx); - let host = std::env::var("TGI_KVROUTER_HOST").unwrap_or("127.0.0.1"); - let port : u16= std::env::var("TGI_KVROUTER_PORT").unwrap_or("3000").parse()?; + let host = std::env::var("TGI_KVROUTER_HOST").unwrap_or("127.0.0.1".to_string()); + let port: u16 = std::env::var("TGI_KVROUTER_PORT") + .unwrap_or("3000".to_string()) + .parse()?; tokio::task::spawn(async move { if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" { println!("Using round robin"); @@ -46,7 +48,8 @@ async fn main() { .with_state(communicator); // run it - let listener = tokio::net::TcpListener::bind((HOST, PORT)).await.unwrap(); + let listener = tokio::net::TcpListener::bind((host, port)).await.unwrap(); println!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app).await.unwrap(); + Ok(()) }