add Dockerfile for kvrouter

add router to set backends
Corentin REGAL 2025-03-31 16:14:12 +02:00
parent 50c8ebdef0
commit 40b2011b3a
4 changed files with 137 additions and 29 deletions

Dockerfile_router (new file, 48 lines added)

@@ -0,0 +1,48 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
# COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY kvrouter kvrouter
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json --bin kvrouter

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
# COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY kvrouter kvrouter
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen --bin kvrouter

# Text Generation Inference base image for router
FROM ubuntu:22.04 AS router

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    curl \
    ca-certificates

# Install router
COPY --from=builder /usr/src/target/release-opt/kvrouter /usr/local/bin/kvrouter

ENTRYPOINT ["kvrouter"]
CMD ["--json-output"]


@@ -14,5 +14,6 @@ hyper = { version = "1.5.2", features = ["full"] }
 hyper-util = { version = "0.1.10", features = ["full"] }
 log = "0.4.25"
 rand = "0.9.0"
+serde = "1"
 slotmap = "1.0.7"
 tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }


@@ -3,13 +3,18 @@ use axum::{
     extract::{Request, State},
     http::uri::Uri,
     response::{IntoResponse, Response},
+    Json,
 };
 use futures_util::stream::StreamExt;
 use hyper::StatusCode;
 use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
 use rand::{rng, Rng};
-use std::sync::atomic::{AtomicUsize, Ordering};
-use tokio::sync::{mpsc, oneshot};
+use serde::Deserialize;
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
+};
+use tokio::sync::{mpsc, oneshot, RwLock};

 mod trie;
@@ -85,7 +90,7 @@ impl LoadBalancer for RoundRobin {

 pub struct OverloadHandler<T: LoadBalancer> {
     load_balancer: T,
-    backends: Vec<String>,
+    backends: Arc<RwLock<Vec<String>>>,
     inqueue: Vec<AtomicUsize>,
     inflight: Vec<AtomicUsize>,
     factor: f32,
@@ -93,9 +98,19 @@ pub struct OverloadHandler<T: LoadBalancer> {
 }

 impl<T: LoadBalancer> OverloadHandler<T> {
-    pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
-        let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
-        let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
+    pub async fn new(load_balancer: T, backends: Arc<RwLock<Vec<String>>>, rx: Rcv) -> Self {
+        let inflight = backends
+            .read()
+            .await
+            .iter()
+            .map(|_| AtomicUsize::new(0))
+            .collect();
+        let inqueue = backends
+            .read()
+            .await
+            .iter()
+            .map(|_| AtomicUsize::new(0))
+            .collect();
         let factor: f32 = std::env::var(FACTOR_KEY)
             .unwrap_or("1.5".to_string())
             .parse()
@@ -110,10 +125,14 @@ impl<T: LoadBalancer> OverloadHandler<T> {
         }
     }

-    fn next(&mut self, key: &[u8]) -> String {
+    async fn next(&mut self, key: &[u8]) -> Option<String> {
+        let backends = self.backends.read().await;
+        if backends.is_empty() {
+            return None;
+        }
         // Get the backend URL
-        let index = self.load_balancer.next(key, self.backends.len());
-        let n = self.backends.len();
+        let index = self.load_balancer.next(key, backends.len());
+        let n = backends.len();
         let mut index = index % n;

         let mut inflight = self.inflight[index].load(Ordering::Relaxed);
@@ -129,14 +148,14 @@ impl<T: LoadBalancer> OverloadHandler<T> {
                );
            }
            index += 1;
-            index %= self.backends.len();
+            index %= backends.len();
            inflight = self.inflight[index].load(Ordering::Relaxed);
            inqueue = self.inflight[index].load(Ordering::Relaxed);
        }
-        let backend = &self.backends[index];
+        let backend = &backends[index];
        self.inflight[index].fetch_add(1, Ordering::Relaxed);
        self.inqueue[index].fetch_add(1, Ordering::Relaxed);
-        backend.to_string()
+        Some(backend.to_string())
    }

    pub async fn run(&mut self) {
@@ -144,31 +163,49 @@ impl<T: LoadBalancer> OverloadHandler<T> {
             eprintln!("Msg {msg:?}");
             match msg {
                 Msg::Next(key, sx) => {
-                    let backend: String = self.next(&key);
+                    let Some(backend) = self.next(&key).await else {
+                        drop(sx);
+                        return;
+                    };
                     eprintln!("Sending back backend {backend}");
                     if let Err(err) = sx.send(backend) {
                         eprintln!("Cannot send back result: {err}");
                     }
                 }
                 Msg::Dequeue(backend) => {
-                    let index = self.backends.iter().position(|b| b == &backend);
+                    let index = self
+                        .backends
+                        .read()
+                        .await
+                        .iter()
+                        .position(|b| b == &backend);
                     if let Some(index) = index {
                         self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
                     }
                 }
                 Msg::Deflight(backend) => {
-                    let index = self.backends.iter().position(|b| b == &backend);
+                    let index = self
+                        .backends
+                        .read()
+                        .await
+                        .iter()
+                        .position(|b| b == &backend);
                     if let Some(index) = index {
                         self.inflight[index].fetch_sub(1, Ordering::Relaxed);
                     }
                 }
                 Msg::AddBackend(backend) => {
-                    self.backends.push(backend);
-                    self.backends.sort();
+                    let mut backends = self.backends.write().await;
+                    backends.push(backend);
+                    backends.sort();
                 }
                 Msg::RemoveBackend(backend) => {
-                    self.backends.retain(|b| *b == backend);
-                    self.backends.sort();
+                    let mut backends = self.backends.write().await;
+                    backends.retain(|b| *b == backend);
+                    backends.sort();
+                }
+                Msg::SetBackends(backends) => {
+                    *self.backends.write().await = backends;
                 }
             }
         }
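
Aside (not part of the diff): the switch from a plain Vec<String> to Arc<RwLock<Vec<String>>> is what allows the new SetBackends message to swap the whole backend list while the run() loop keeps routing with short read locks. A minimal sketch of that sharing pattern, using only tokio; the names below are illustrative, not taken from the repo:

use std::sync::Arc;
use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    // Shared list, cloned into whichever task needs it.
    let backends = Arc::new(RwLock::new(vec!["http://localhost:8000".to_string()]));

    // The routing side only takes short read locks, as OverloadHandler::next does.
    println!("before: {:?}", *backends.read().await);

    // A control path can replace the whole list, as Msg::SetBackends does.
    *backends.write().await = vec![
        "http://localhost:8001".to_string(),
        "http://localhost:8002".to_string(),
    ];

    println!("after: {:?}", *backends.read().await);
}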
@@ -186,6 +223,7 @@ pub enum Msg {
     Deflight(String),
     AddBackend(String),
     RemoveBackend(String),
+    SetBackends(Vec<String>),
 }

 type Snd = mpsc::Sender<Msg>;
@@ -215,7 +253,9 @@ impl Communicator {
     async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
         let (sx, rx) = oneshot::channel();
         self.sender.send(Msg::Next(key, sx)).await?;
-        let backend = rx.await.unwrap();
+        let backend = rx
+            .await
+            .map_err(|_| mpsc::error::SendError(Msg::AddBackend("todo".to_string())))?;
         Ok(backend)
     }
 }
@@ -284,3 +324,16 @@ pub async fn handler(

     Ok(Response::from_parts(parts, body))
 }
+
+#[derive(Deserialize)]
+pub struct SetBackends {
+    backends: Vec<String>,
+}
+
+pub async fn set_backends_handler(
+    State(state): State<Communicator>,
+    Json(SetBackends { backends }): Json<SetBackends>,
+) -> impl IntoResponse {
+    let _ = state.sender.send(Msg::SetBackends(backends)).await;
+    StatusCode::OK
+}
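
Aside (not part of the diff): the new set_backends_handler deserializes a JSON object with a single backends array. A small sketch of that payload, assuming serde_json is available (it is not in the dependency list above) and mirroring the struct locally since its field is private:

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct SetBackends {
    backends: Vec<String>,
}

fn main() {
    // The shape the /_kvrouter/set-backends route expects.
    let body = r#"{"backends": ["http://localhost:8000", "http://localhost:8001"]}"#;
    let parsed: SetBackends = serde_json::from_str(body).expect("valid payload");
    assert_eq!(parsed.backends.len(), 2);
    println!("{parsed:?}");
}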


@@ -1,18 +1,24 @@
+use std::sync::Arc;
+
 use axum::{
     routing::Router,
     routing::{get, post},
 };
-use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};
+use hyper::StatusCode;
+use kvrouter::{
+    handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin,
+};
+use tokio::sync::RwLock;

 #[tokio::main]
 async fn main() {
     // List of backend servers
-    let backends = vec![
-        "http://localhost:8000".to_string(),
+    let backends = Arc::new(RwLock::new(vec![
+        // "http://localhost:8000".to_string(),
         // "http://localhost:8001".to_string(),
         // "http://localhost:8002".to_string(),
         // "http://localhost:8003".to_string(),
-    ];
+    ]));

     // Create a new instance of the RoundRobinRouter
@@ -25,23 +31,23 @@ async fn main() {
         if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
             println!("Using round robin");
             let lb = RoundRobin::new();
-            let mut router = OverloadHandler::new(lb, backends, rx);
+            let mut router = OverloadHandler::new(lb, backends, rx).await;
             router.run().await;
         } else {
             let lb = ContentAware::new();
-            let mut router = OverloadHandler::new(lb, backends, rx);
+            let mut router = OverloadHandler::new(lb, backends, rx).await;
             router.run().await;
         };
     });

     let app = Router::new()
         .route("/{*key}", get(handler))
         .route("/{*key}", post(handler))
+        .route("/_kvrouter/health", get(|| async { StatusCode::OK }))
+        .route("/_kvrouter/set-backends", post(set_backends_handler))
         .with_state(communicator);
     // run it
-    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
-        .await
-        .unwrap();
+    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
     println!("listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
 }
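
Aside (not part of the commit): a client-side sketch of exercising the two new routes registered above, assuming reqwest (with its json feature) and serde_json are available in the calling crate; neither is a dependency of this repo:

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();

    // Liveness check against the new health route.
    let health = client
        .get("http://localhost:3000/_kvrouter/health")
        .send()
        .await?;
    println!("health: {}", health.status());

    // Replace the backend list in one call.
    let resp = client
        .post("http://localhost:3000/_kvrouter/set-backends")
        .json(&serde_json::json!({
            "backends": ["http://localhost:8000", "http://localhost:8001"]
        }))
        .send()
        .await?;
    println!("set-backends: {}", resp.status());
    Ok(())
}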