diff --git a/Dockerfile_router b/Dockerfile_router
new file mode 100644
index 00000000..ab839645
--- /dev/null
+++ b/Dockerfile_router
@@ -0,0 +1,49 @@
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
+WORKDIR /usr/src
+
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
+FROM chef AS planner
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+# COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY kvrouter kvrouter
+COPY backends backends
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json --bin kvrouter
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+# COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY kvrouter kvrouter
+COPY backends backends
+COPY launcher launcher
+RUN cargo build --profile release-opt --frozen --bin kvrouter
+
+# Text Generation Inference base image for router
+FROM ubuntu:22.04 AS router
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    ca-certificates \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install router
+COPY --from=builder /usr/src/target/release-opt/kvrouter /usr/local/bin/kvrouter
+
+ENTRYPOINT ["kvrouter"]
+CMD ["--json-output"]
diff --git a/kvrouter/Cargo.toml b/kvrouter/Cargo.toml
index efb2aa2a..13944793 100644
--- a/kvrouter/Cargo.toml
+++ b/kvrouter/Cargo.toml
@@ -14,5 +14,6 @@ hyper = { version = "1.5.2", features = ["full"] }
 hyper-util = { version = "0.1.10", features = ["full"] }
 log = "0.4.25"
 rand = "0.9.0"
+serde = { version = "1", features = ["derive"] }
 slotmap = "1.0.7"
 tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
diff --git a/kvrouter/src/lib.rs b/kvrouter/src/lib.rs
index 874ae192..e2830abc 100644
--- a/kvrouter/src/lib.rs
+++ b/kvrouter/src/lib.rs
@@ -3,13 +3,18 @@ use axum::{
     extract::{Request, State},
     http::uri::Uri,
     response::{IntoResponse, Response},
+    Json,
 };
 use futures_util::stream::StreamExt;
 use hyper::StatusCode;
 use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
 use rand::{rng, Rng};
-use std::sync::atomic::{AtomicUsize, Ordering};
-use tokio::sync::{mpsc, oneshot};
+use serde::Deserialize;
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
+};
+use tokio::sync::{mpsc, oneshot, RwLock};
 
 mod trie;
 
@@ -85,7 +90,7 @@ impl LoadBalancer for RoundRobin {
 
 pub struct OverloadHandler<T> {
     load_balancer: T,
-    backends: Vec<String>,
+    backends: Arc<RwLock<Vec<String>>>,
     inqueue: Vec<AtomicUsize>,
     inflight: Vec<AtomicUsize>,
     factor: f32,
@@ -93,9 +98,12 @@ pub struct OverloadHandler<T> {
 }
 
 impl<T: LoadBalancer> OverloadHandler<T> {
-    pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
-        let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
-        let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
+    /// Builds a handler with one zeroed in-flight and in-queue counter per backend.
+    pub async fn new(load_balancer: T, backends: Arc<RwLock<Vec<String>>>, rx: Rcv) -> Self {
+        // Take the lock once so both counter vectors are sized from the same snapshot.
+        let n = backends.read().await.len();
+        let inflight = (0..n).map(|_| AtomicUsize::new(0)).collect();
+        let inqueue = (0..n).map(|_| AtomicUsize::new(0)).collect();
         let factor: f32 = std::env::var(FACTOR_KEY)
             .unwrap_or("1.5".to_string())
             .parse()
@@ -110,10 +118,15 @@ impl<T: LoadBalancer> OverloadHandler<T> {
         }
     }
 
-    fn next(&mut self, key: &[u8]) -> String {
+    /// Picks the backend for `key`, or `None` when no backend is registered.
+    async fn next(&mut self, key: &[u8]) -> Option<String> {
+        let backends = self.backends.read().await;
+        if backends.is_empty() {
+            return None;
+        }
         // Get the backend URL
-        let index = self.load_balancer.next(key, self.backends.len());
-        let n = self.backends.len();
+        let index = self.load_balancer.next(key, backends.len());
+        let n = backends.len();
         let mut index = index % n;
 
         let mut inflight = self.inflight[index].load(Ordering::Relaxed);
@@ -129,14 +142,14 @@ impl<T: LoadBalancer> OverloadHandler<T> {
                 );
             }
             index += 1;
-            index %= self.backends.len();
+            index %= backends.len();
             inflight = self.inflight[index].load(Ordering::Relaxed);
             inqueue = self.inflight[index].load(Ordering::Relaxed);
         }
-        let backend = &self.backends[index];
+        let backend = &backends[index];
         self.inflight[index].fetch_add(1, Ordering::Relaxed);
         self.inqueue[index].fetch_add(1, Ordering::Relaxed);
-        backend.to_string()
+        Some(backend.to_string())
     }
 
     pub async fn run(&mut self) {
@@ -144,32 +157,82 @@ impl<T: LoadBalancer> OverloadHandler<T> {
             eprintln!("Msg {msg:?}");
             match msg {
                 Msg::Next(key, sx) => {
-                    let backend: String = self.next(&key);
+                    let Some(backend) = self.next(&key).await else {
+                        // No backend registered yet: dropping the reply channel fails
+                        // this one request while the router keeps serving messages.
+                        drop(sx);
+                        continue;
+                    };
                     eprintln!("Sending back backend {backend}");
                     if let Err(err) = sx.send(backend) {
                         eprintln!("Cannot send back result: {err}");
                     }
                 }
                 Msg::Dequeue(backend) => {
-                    let index = self.backends.iter().position(|b| b == &backend);
+                    let index = self
+                        .backends
+                        .read()
+                        .await
+                        .iter()
+                        .position(|b| b == &backend);
                     if let Some(index) = index {
                         self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
                     }
                 }
                 Msg::Deflight(backend) => {
-                    let index = self.backends.iter().position(|b| b == &backend);
+                    let index = self
+                        .backends
+                        .read()
+                        .await
+                        .iter()
+                        .position(|b| b == &backend);
                     if let Some(index) = index {
                         self.inflight[index].fetch_sub(1, Ordering::Relaxed);
                     }
                 }
                 Msg::AddBackend(backend) => {
-                    self.backends.push(backend);
-                    self.backends.sort();
+                    let mut backends = self.backends.read().await.to_vec();
+                    backends.push(backend);
+                    backends.sort();
+                    self.replace_backends(backends).await;
                 }
                 Msg::RemoveBackend(backend) => {
-                    self.backends.retain(|b| *b == backend);
-                    self.backends.sort();
+                    let mut backends = self.backends.read().await.to_vec();
+                    // Keep every backend except the one being removed.
+                    backends.retain(|b| *b != backend);
+                    backends.sort();
+                    self.replace_backends(backends).await;
+                }
+                Msg::SetBackends(backends) => {
+                    self.replace_backends(backends).await;
                 }
             }
         }
     }
+
+    /// Installs `new_backends` and rebuilds both counter vectors so that
+    /// `inflight[i]` and `inqueue[i]` keep describing `backends[i]`.
+    ///
+    /// Counts of backends that stay registered are carried over by URL; newly
+    /// added backends start at zero.
+    async fn replace_backends(&mut self, new_backends: Vec<String>) {
+        let mut backends = self.backends.write().await;
+        let carry_over = |counters: &[AtomicUsize]| -> Vec<AtomicUsize> {
+            new_backends
+                .iter()
+                .map(|url| {
+                    let count = backends
+                        .iter()
+                        .position(|b| b == url)
+                        .and_then(|i| counters.get(i))
+                        .map_or(0, |c| c.load(Ordering::Relaxed));
+                    AtomicUsize::new(count)
+                })
+                .collect()
+        };
+        let inflight = carry_over(&self.inflight);
+        let inqueue = carry_over(&self.inqueue);
+        self.inflight = inflight;
+        self.inqueue = inqueue;
+        *backends = new_backends;
+    }
@@ -186,6 +249,7 @@ pub enum Msg {
     Deflight(String),
     AddBackend(String),
     RemoveBackend(String),
+    SetBackends(Vec<String>),
 }
 
 type Snd = mpsc::Sender<Msg>;
@@ -215,7 +279,9 @@ impl Communicator {
     async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
         let (sx, rx) = oneshot::channel();
         self.sender.send(Msg::Next(key, sx)).await?;
-        let backend = rx.await.unwrap();
+        let backend = rx
+            .await
+            .map_err(|_| mpsc::error::SendError(Msg::AddBackend("todo".to_string())))?;
         Ok(backend)
     }
 }
@@ -284,3 +350,23 @@ pub async fn handler(
 
     Ok(Response::from_parts(parts, body))
 }
+
+/// JSON body accepted by the set-backends endpoint.
+#[derive(Deserialize)]
+pub struct SetBackends {
+    backends: Vec<String>,
+}
+
+/// Replaces the router's backend list with the one supplied in the request body.
+///
+/// Responds `200 OK` once the update is queued and `503 Service Unavailable`
+/// when the router task is no longer receiving messages.
+pub async fn set_backends_handler(
+    State(state): State<Communicator>,
+    Json(SetBackends { backends }): Json<SetBackends>,
+) -> impl IntoResponse {
+    match state.sender.send(Msg::SetBackends(backends)).await {
+        Ok(()) => StatusCode::OK,
+        Err(_) => StatusCode::SERVICE_UNAVAILABLE,
+    }
+}
diff --git a/kvrouter/src/main.rs b/kvrouter/src/main.rs
index b5081545..6ff1e0b5 100644
--- a/kvrouter/src/main.rs
+++ b/kvrouter/src/main.rs
@@ -1,18 +1,24 @@
+use std::sync::Arc;
+
 use axum::{
     routing::Router,
     routing::{get, post},
 };
-use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};
+use hyper::StatusCode;
+use kvrouter::{
+    handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin,
+};
+use tokio::sync::RwLock;
 
 #[tokio::main]
 async fn main() {
     // List of backend servers
-    let backends = vec![
-        "http://localhost:8000".to_string(),
+    let backends = Arc::new(RwLock::new(vec![
+        // "http://localhost:8000".to_string(),
         // "http://localhost:8001".to_string(),
         // "http://localhost:8002".to_string(),
         // "http://localhost:8003".to_string(),
-    ];
+    ]));
 
     // Create a new instance of the RoundRobinRouter
 
@@ -25,23 +31,23 @@ async fn main() {
         if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
             println!("Using round robin");
             let lb = RoundRobin::new();
-            let mut router = OverloadHandler::new(lb, backends, rx);
+            let mut router = OverloadHandler::new(lb, backends, rx).await;
             router.run().await;
         } else {
             let lb = ContentAware::new();
-            let mut router = OverloadHandler::new(lb, backends, rx);
+            let mut router = OverloadHandler::new(lb, backends, rx).await;
             router.run().await;
         };
     });
     let app = Router::new()
         .route("/{*key}", get(handler))
         .route("/{*key}", post(handler))
+        .route("/_kvrouter/health", get(|| async { StatusCode::OK }))
+        .route("/_kvrouter/set-backends", post(set_backends_handler))
         .with_state(communicator);
 
     // run it
-    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
-        .await
-        .unwrap();
+    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
     println!("listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
 }