mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 02:42:05 +00:00

add Dockerfile for kvrouter
add router to set backends

This commit is contained in:
parent 50c8ebdef0
commit 40b2011b3a

Dockerfile_router (new file, 48 lines)
@@ -0,0 +1,48 @@
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
+WORKDIR /usr/src
+
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
+FROM chef AS planner
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+# COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY kvrouter kvrouter
+COPY backends backends
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json --bin kvrouter
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+# COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY kvrouter kvrouter
+COPY backends backends
+COPY launcher launcher
+RUN cargo build --profile release-opt --frozen --bin kvrouter
+
+# Text Generation Inference base image for router
+FROM ubuntu:22.04 AS router
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    ca-certificates
+
+# Install router
+COPY --from=builder /usr/src/target/release-opt/kvrouter /usr/local/bin/kvrouter
+
+ENTRYPOINT ["kvrouter"]
+CMD ["--json-output"]
kvrouter/Cargo.toml
@@ -14,5 +14,6 @@ hyper = { version = "1.5.2", features = ["full"] }
 hyper-util = { version = "0.1.10", features = ["full"] }
 log = "0.4.25"
 rand = "0.9.0"
+serde = "1"
 slotmap = "1.0.7"
 tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
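Note: the new serde = "1" entry backs the Deserialize derive used by the SetBackends request body added in lib.rs below. A minimal sketch of the JSON shape that body expects; serde_json is not added by this commit and is assumed here purely for illustration.

// Sketch only: serde_json is assumed for illustration, it is not a dependency added by this diff.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct SetBackends {
    backends: Vec<String>,
}

fn main() {
    let body = r#"{"backends": ["http://localhost:8000", "http://localhost:8001"]}"#;
    let parsed: SetBackends = serde_json::from_str(body).expect("valid SetBackends JSON");
    println!("got {} backends", parsed.backends.len());
}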
kvrouter/src/lib.rs
@@ -3,13 +3,18 @@ use axum::{
     extract::{Request, State},
     http::uri::Uri,
     response::{IntoResponse, Response},
+    Json,
 };
 use futures_util::stream::StreamExt;
 use hyper::StatusCode;
 use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
 use rand::{rng, Rng};
-use std::sync::atomic::{AtomicUsize, Ordering};
-use tokio::sync::{mpsc, oneshot};
+use serde::Deserialize;
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
+};
+use tokio::sync::{mpsc, oneshot, RwLock};

 mod trie;

@@ -85,7 +90,7 @@ impl LoadBalancer for RoundRobin {

 pub struct OverloadHandler<T: LoadBalancer> {
     load_balancer: T,
-    backends: Vec<String>,
+    backends: Arc<RwLock<Vec<String>>>,
     inqueue: Vec<AtomicUsize>,
     inflight: Vec<AtomicUsize>,
     factor: f32,
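Note on the field change above: wrapping the backend list in Arc<RwLock<Vec<String>>> is what allows the list to be replaced at runtime while the OverloadHandler task and every other holder of the Arc keep a consistent view. A minimal, standalone sketch of that sharing pattern (not the kvrouter types):

use std::sync::Arc;
use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    // Shared, mutable backend list; clones of the Arc all point at the same data.
    let backends: Arc<RwLock<Vec<String>>> = Arc::new(RwLock::new(Vec::new()));

    let writer = Arc::clone(&backends);
    tokio::spawn(async move {
        // Exclusive write access; readers wait until the guard is dropped.
        writer.write().await.push("http://localhost:8000".to_string());
    })
    .await
    .unwrap();

    // Any holder of the Arc observes the update on its next read.
    assert_eq!(backends.read().await.len(), 1);
}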
@@ -93,9 +98,19 @@ pub struct OverloadHandler<T: LoadBalancer> {
 }

 impl<T: LoadBalancer> OverloadHandler<T> {
-    pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
-        let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
-        let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
+    pub async fn new(load_balancer: T, backends: Arc<RwLock<Vec<String>>>, rx: Rcv) -> Self {
+        let inflight = backends
+            .read()
+            .await
+            .iter()
+            .map(|_| AtomicUsize::new(0))
+            .collect();
+        let inqueue = backends
+            .read()
+            .await
+            .iter()
+            .map(|_| AtomicUsize::new(0))
+            .collect();
         let factor: f32 = std::env::var(FACTOR_KEY)
             .unwrap_or("1.5".to_string())
             .parse()
@@ -110,10 +125,14 @@ impl<T: LoadBalancer> OverloadHandler<T> {
         }
     }

-    fn next(&mut self, key: &[u8]) -> String {
+    async fn next(&mut self, key: &[u8]) -> Option<String> {
+        let backends = self.backends.read().await;
+        if backends.is_empty() {
+            return None;
+        }
         // Get the backend URL
-        let index = self.load_balancer.next(key, self.backends.len());
-        let n = self.backends.len();
+        let index = self.load_balancer.next(key, backends.len());
+        let n = backends.len();
         let mut index = index % n;

         let mut inflight = self.inflight[index].load(Ordering::Relaxed);
@@ -129,14 +148,14 @@ impl<T: LoadBalancer> OverloadHandler<T> {
                );
            }
            index += 1;
-            index %= self.backends.len();
+            index %= backends.len();
            inflight = self.inflight[index].load(Ordering::Relaxed);
            inqueue = self.inflight[index].load(Ordering::Relaxed);
        }
-        let backend = &self.backends[index];
+        let backend = &backends[index];
        self.inflight[index].fetch_add(1, Ordering::Relaxed);
        self.inqueue[index].fetch_add(1, Ordering::Relaxed);
-        backend.to_string()
+        Some(backend.to_string())
    }

    pub async fn run(&mut self) {
@@ -144,31 +163,49 @@ impl<T: LoadBalancer> OverloadHandler<T> {
             eprintln!("Msg {msg:?}");
             match msg {
                 Msg::Next(key, sx) => {
-                    let backend: String = self.next(&key);
+                    let Some(backend) = self.next(&key).await else {
+                        drop(sx);
+                        return;
+                    };
                     eprintln!("Sending back backend {backend}");
                     if let Err(err) = sx.send(backend) {
                         eprintln!("Cannot send back result: {err}");
                     }
                 }
                 Msg::Dequeue(backend) => {
-                    let index = self.backends.iter().position(|b| b == &backend);
+                    let index = self
+                        .backends
+                        .read()
+                        .await
+                        .iter()
+                        .position(|b| b == &backend);
                     if let Some(index) = index {
                         self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
                     }
                 }
                 Msg::Deflight(backend) => {
-                    let index = self.backends.iter().position(|b| b == &backend);
+                    let index = self
+                        .backends
+                        .read()
+                        .await
+                        .iter()
+                        .position(|b| b == &backend);
                     if let Some(index) = index {
                         self.inflight[index].fetch_sub(1, Ordering::Relaxed);
                     }
                 }
                 Msg::AddBackend(backend) => {
-                    self.backends.push(backend);
-                    self.backends.sort();
+                    let mut backends = self.backends.write().await;
+                    backends.push(backend);
+                    backends.sort();
                 }
                 Msg::RemoveBackend(backend) => {
-                    self.backends.retain(|b| *b == backend);
-                    self.backends.sort();
+                    let mut backends = self.backends.write().await;
+                    backends.retain(|b| *b == backend);
+                    backends.sort();
+                }
+                Msg::SetBackends(backends) => {
+                    *self.backends.write().await = backends;
                 }
             }
         }
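The run loop above acts as a small actor: every mutation of the shared state goes through the mpsc channel, so the handler applies messages one at a time. A stripped-down sketch of the same pattern, reduced to a SetBackends-style message (the names here are illustrative, not the kvrouter API):

use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};

// Illustrative message type mirroring Msg::SetBackends from the diff above.
enum Msg {
    SetBackends(Vec<String>),
}

#[tokio::main]
async fn main() {
    let backends = Arc::new(RwLock::new(Vec::<String>::new()));
    let (tx, mut rx) = mpsc::channel::<Msg>(16);

    // Actor task: owns the receiver and applies every message in order.
    let state = Arc::clone(&backends);
    let actor = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            match msg {
                Msg::SetBackends(list) => *state.write().await = list,
            }
        }
    });

    // Caller side: swap in a whole new backend list with one message.
    tx.send(Msg::SetBackends(vec!["http://localhost:8000".to_string()]))
        .await
        .unwrap();
    drop(tx); // closing the sender ends the actor loop

    actor.await.unwrap();
    assert_eq!(backends.read().await.len(), 1);
}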
@@ -186,6 +223,7 @@ pub enum Msg {
     Deflight(String),
     AddBackend(String),
     RemoveBackend(String),
+    SetBackends(Vec<String>),
 }

 type Snd = mpsc::Sender<Msg>;
@@ -215,7 +253,9 @@ impl Communicator {
     async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
         let (sx, rx) = oneshot::channel();
         self.sender.send(Msg::Next(key, sx)).await?;
-        let backend = rx.await.unwrap();
+        let backend = rx
+            .await
+            .map_err(|_| mpsc::error::SendError(Msg::AddBackend("todo".to_string())))?;
         Ok(backend)
     }
 }
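The change above replaces the unwrap on the oneshot receiver with an error mapping, so a dropped reply channel no longer panics the caller. For context, a minimal sketch of the mpsc-plus-oneshot request/response pattern the Communicator relies on (standalone names, not the kvrouter API):

use tokio::sync::{mpsc, oneshot};

// Illustrative request type: the caller ships a oneshot sender for the reply.
struct Request {
    key: Vec<u8>,
    reply: oneshot::Sender<String>,
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Request>(16);

    // Server side: answer each request through its embedded oneshot sender.
    tokio::spawn(async move {
        while let Some(req) = rx.recv().await {
            let _ = req.reply.send(format!("backend for key of {} bytes", req.key.len()));
        }
    });

    // Client side: send the request, then await the reply; a dropped sender
    // surfaces as a RecvError instead of a panic.
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send(Request { key: vec![1, 2, 3], reply: reply_tx }).await.unwrap();
    match reply_rx.await {
        Ok(backend) => println!("{backend}"),
        Err(_) => eprintln!("router task dropped the reply channel"),
    }
}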
@@ -284,3 +324,16 @@ pub async fn handler(

     Ok(Response::from_parts(parts, body))
 }
+
+#[derive(Deserialize)]
+pub struct SetBackends {
+    backends: Vec<String>,
+}
+
+pub async fn set_backends_handler(
+    State(state): State<Communicator>,
+    Json(SetBackends { backends }): Json<SetBackends>,
+) -> impl IntoResponse {
+    let _ = state.sender.send(Msg::SetBackends(backends)).await;
+    StatusCode::OK
+}
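With the handler above mounted at /_kvrouter/set-backends (see main.rs below), the backend list can be replaced at runtime with a single POST. A hedged client sketch; reqwest and serde_json are assumed here only for illustration and are not dependencies of this crate. A plain curl with the same URL and JSON body works just as well:

// Illustration only: reqwest and serde_json are assumed, they are not part of this diff.
// Equivalent shell call:
//   curl -X POST http://localhost:3000/_kvrouter/set-backends \
//     -H 'Content-Type: application/json' \
//     -d '{"backends": ["http://localhost:8000", "http://localhost:8001"]}'
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    let res = client
        .post("http://localhost:3000/_kvrouter/set-backends")
        .json(&serde_json::json!({
            "backends": ["http://localhost:8000", "http://localhost:8001"]
        }))
        .send()
        .await?;
    println!("set-backends returned {}", res.status());
    Ok(())
}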
kvrouter/src/main.rs
@@ -1,18 +1,24 @@
+use std::sync::Arc;
+
 use axum::{
     routing::Router,
     routing::{get, post},
 };
-use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};
+use hyper::StatusCode;
+use kvrouter::{
+    handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin,
+};
+use tokio::sync::RwLock;

 #[tokio::main]
 async fn main() {
     // List of backend servers
-    let backends = vec![
-        "http://localhost:8000".to_string(),
+    let backends = Arc::new(RwLock::new(vec![
+        // "http://localhost:8000".to_string(),
         // "http://localhost:8001".to_string(),
         // "http://localhost:8002".to_string(),
         // "http://localhost:8003".to_string(),
-    ];
+    ]));

     // Create a new instance of the RoundRobinRouter

@@ -25,23 +31,23 @@ async fn main() {
         if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
             println!("Using round robin");
             let lb = RoundRobin::new();
-            let mut router = OverloadHandler::new(lb, backends, rx);
+            let mut router = OverloadHandler::new(lb, backends, rx).await;
             router.run().await;
         } else {
             let lb = ContentAware::new();
-            let mut router = OverloadHandler::new(lb, backends, rx);
+            let mut router = OverloadHandler::new(lb, backends, rx).await;
             router.run().await;
         };
     });
     let app = Router::new()
         .route("/{*key}", get(handler))
         .route("/{*key}", post(handler))
+        .route("/_kvrouter/health", get(|| async { StatusCode::OK }))
+        .route("/_kvrouter/set-backends", post(set_backends_handler))
         .with_state(communicator);

     // run it
-    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
-        .await
-        .unwrap();
+    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
     println!("listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
 }