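//! Entry point of the kvrouter binary (`text-generation-inference/kvrouter/src/main.rs`).
//!
//! Starts a load-balancing task and an HTTP front end whose handlers reach that
//! task through a shared `Communicator`.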
use std::sync::Arc;
use axum::{
routing::Router,
routing::{get, post},
};
use hyper::StatusCode;
use kvrouter::{
handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin,
};
use tokio::sync::RwLock;
#[tokio::main]
async fn main() {
// List of backend servers
let backends = Arc::new(RwLock::new(vec![
// "http://localhost:8000".to_string(),
2025-01-31 08:07:31 +00:00
// "http://localhost:8001".to_string(),
// "http://localhost:8002".to_string(),
// "http://localhost:8003".to_string(),
]));
2025-01-28 18:48:17 +00:00
// Create a new instance of the RoundRobinRouter
2025-01-31 08:07:31 +00:00
println!("Using Content aware");
// Create the Axum router
2025-01-28 18:48:17 +00:00
2025-01-31 08:07:31 +00:00
let (sx, rx) = tokio::sync::mpsc::channel(100);
let communicator = Communicator::new(sx);
tokio::task::spawn(async move {
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
println!("Using round robin");
let lb = RoundRobin::new();
let mut router = OverloadHandler::new(lb, backends, rx).await;
2025-01-31 08:07:31 +00:00
router.run().await;
} else {
let lb = ContentAware::new();
let mut router = OverloadHandler::new(lb, backends, rx).await;
2025-01-31 08:07:31 +00:00
router.run().await;
};
});
let app = Router::new()
.route("/{*key}", get(handler))
.route("/{*key}", post(handler))
.route("/_kvrouter/health", get(|| async { StatusCode::OK }))
.route("/_kvrouter/set-backends", post(set_backends_handler))
2025-01-31 08:07:31 +00:00
.with_state(communicator);
// run it
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
2025-01-31 08:07:31 +00:00
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
2025-01-28 18:48:17 +00:00
}