Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-20 22:32:07 +00:00

Commit 57fa04adfd (parent 1932c5b9ed)

Cleaner version.

Changed files shown below include Cargo.lock (generated, 5 changed lines).
Cargo.lock

@@ -2254,6 +2254,7 @@ dependencies = [
  "futures-util",
  "hyper 1.5.2",
  "hyper-util",
+ "log",
  "rand 0.9.0",
  "slotmap",
  "tokio",
@@ -2352,9 +2353,9 @@ dependencies = [

 [[package]]
 name = "log"
-version = "0.4.22"
+version = "0.4.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
+checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"

 [[package]]
 name = "loop9"
@@ -12,6 +12,7 @@ futures = "0.3.31"
 futures-util = "0.3.31"
 hyper = { version = "1.5.2", features = ["full"] }
 hyper-util = { version = "0.1.10", features = ["full"] }
+log = "0.4.25"
 rand = "0.9.0"
 slotmap = "1.0.7"
 tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
@@ -5,10 +5,11 @@ use axum::{
     response::{IntoResponse, Response},
 };
 use futures_util::stream::StreamExt;
+use hyper::StatusCode;
 use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
 use rand::{rng, Rng};
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::{Arc, Mutex};
+use tokio::sync::{mpsc, oneshot};

 mod trie;

@@ -17,9 +18,8 @@ use crate::trie::Trie;
 const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
 type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;

-#[derive(Clone)]
 pub struct ContentAware {
-    trie: Arc<Mutex<Trie>>,
+    trie: Trie,
 }

 impl Default for ContentAware {
@@ -30,14 +30,14 @@ impl Default for ContentAware {

 impl ContentAware {
     pub fn new() -> Self {
-        let trie = Arc::new(Mutex::new(Trie::new()));
+        let trie = Trie::new();
         Self { trie }
     }
 }

 impl LoadBalancer for ContentAware {
     fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
-        let mut trie = self.trie.lock().unwrap();
+        let trie = &mut self.trie;
         let (start, stop) = trie.insert(key);
         let n = trie.count();
         eprintln!(
@@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
     }
 }

-#[derive(Clone)]
 pub struct RoundRobin {
-    current: Arc<AtomicUsize>,
+    current: AtomicUsize,
 }

 impl Default for RoundRobin {
@@ -73,7 +72,7 @@ impl Default for RoundRobin {

 impl RoundRobin {
     pub fn new() -> Self {
-        let current = Arc::new(AtomicUsize::new(0));
+        let current = AtomicUsize::new(0);
         Self { current }
     }
 }
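The hunks above drop #[derive(Clone)] and unwrap Arc<Mutex<Trie>> and Arc<AtomicUsize> into plain fields: after this commit each balancer is owned by a single task instead of being cloned into every request handler, so exclusive access comes from &mut self rather than from locks. A minimal sketch of that ownership model, reusing the LoadBalancer signature from the diff; the modulo round-robin body and the backend count of 4 are assumptions for illustration, not taken from this commit.

    use std::sync::atomic::{AtomicUsize, Ordering};

    // Same shape as the trait in the diff: &mut self gives exclusive access,
    // so no Mutex or Arc is needed once a single task owns the balancer.
    pub trait LoadBalancer {
        fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
    }

    pub struct RoundRobin {
        current: AtomicUsize,
    }

    impl LoadBalancer for RoundRobin {
        // Assumed body: simple modulo round robin (not shown in the hunk).
        fn next(&mut self, _key: &[u8], n_backends: usize) -> usize {
            self.current.fetch_add(1, Ordering::Relaxed) % n_backends
        }
    }

    fn main() {
        let mut lb = RoundRobin { current: AtomicUsize::new(0) };
        for _ in 0..3 {
            // 4 backends is an arbitrary count for the example.
            println!("backend index: {}", lb.next(b"prompt", 4));
        }
    }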
@@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
     }
 }

-#[derive(Clone)]
 pub struct OverloadHandler<T: LoadBalancer> {
-    client: Client,
     load_balancer: T,
-    backends: Arc<Vec<String>>,
-    inqueue: Arc<Vec<AtomicUsize>>,
-    inflight: Arc<Vec<AtomicUsize>>,
+    backends: Vec<String>,
+    inqueue: Vec<AtomicUsize>,
+    inflight: Vec<AtomicUsize>,
     factor: f32,
+    rx: Rcv,
 }

 impl<T: LoadBalancer> OverloadHandler<T> {
-    pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
-        let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
-            .build(HttpConnector::new());
-        let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
-        let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
+    pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
+        let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
+        let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
         let factor: f32 = std::env::var(FACTOR_KEY)
             .unwrap_or("1.5".to_string())
             .parse()
             .unwrap_or(1.5);
-        let backends = Arc::new(backends);
         Self {
             load_balancer,
             backends,
-            client,
             factor,
             inflight,
             inqueue,
+            rx,
         }
     }

-    fn next(&mut self, key: &[u8]) -> usize {
+    fn next(&mut self, key: &[u8]) -> String {
         // Get the backend URL
         let index = self.load_balancer.next(key, self.backends.len());
         let n = self.backends.len();
@@ -138,7 +133,45 @@ impl<T: LoadBalancer> OverloadHandler<T> {
             inflight = self.inflight[index].load(Ordering::Relaxed);
             inqueue = self.inflight[index].load(Ordering::Relaxed);
         }
-        index
+        let backend = &self.backends[index];
+        self.inflight[index].fetch_add(1, Ordering::Relaxed);
+        self.inqueue[index].fetch_add(1, Ordering::Relaxed);
+        backend.to_string()
+    }
+
+    pub async fn run(&mut self) {
+        while let Some(msg) = self.rx.recv().await {
+            eprintln!("Msg {msg:?}");
+            match msg {
+                Msg::Next(key, sx) => {
+                    let backend: String = self.next(&key);
+                    eprintln!("Sending back backend {backend}");
+                    if let Err(err) = sx.send(backend) {
+                        eprintln!("Cannot send back result: {err}");
+                    }
+                }
+                Msg::Dequeue(backend) => {
+                    let index = self.backends.iter().position(|b| b == &backend);
+                    if let Some(index) = index {
+                        self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
+                    }
+                }
+                Msg::Deflight(backend) => {
+                    let index = self.backends.iter().position(|b| b == &backend);
+                    if let Some(index) = index {
+                        self.inflight[index].fetch_sub(1, Ordering::Relaxed);
+                    }
+                }
+                Msg::AddBackend(backend) => {
+                    self.backends.push(backend);
+                    self.backends.sort();
+                }
+                Msg::RemoveBackend(backend) => {
+                    self.backends.retain(|b| *b == backend);
+                    self.backends.sort();
+                }
+            }
+        }
     }
 }

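With run() in place, OverloadHandler becomes a single-owner actor: every counter update and backend change arrives as a Msg over rx, which is what lets the struct shed its Arc wrappers. Below is a standalone sketch of the same message-loop pattern with tokio channels; the Cmd type, the key-length backend choice, and the single localhost URL are illustrative only, and the Remove arm here drops the named backend with retain(|b| *b != backend), which is the opposite predicate of the committed RemoveBackend arm.

    use tokio::sync::{mpsc, oneshot};

    // Illustrative message set mirroring the shape of Msg in the diff.
    #[derive(Debug)]
    enum Cmd {
        Next(Vec<u8>, oneshot::Sender<String>),
        Remove(String),
    }

    struct Actor {
        backends: Vec<String>,
        rx: mpsc::Receiver<Cmd>,
    }

    impl Actor {
        async fn run(&mut self) {
            // All mutation happens in this one task; callers only send messages.
            while let Some(cmd) = self.rx.recv().await {
                match cmd {
                    Cmd::Next(key, reply) => {
                        // Placeholder policy: derive an index from the key length;
                        // the real handler delegates to a LoadBalancer.
                        let idx = key.len() % self.backends.len().max(1);
                        let backend = self.backends.get(idx).cloned().unwrap_or_default();
                        let _ = reply.send(backend);
                    }
                    Cmd::Remove(backend) => {
                        // Assumed intent: remove the named backend from rotation.
                        self.backends.retain(|b| *b != backend);
                    }
                }
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = mpsc::channel(16);
        tokio::task::spawn(async move {
            let mut actor = Actor {
                backends: vec!["http://localhost:8000".to_string()],
                rx,
            };
            actor.run().await;
        });

        let (sx, reply) = oneshot::channel();
        tx.send(Cmd::Next(b"prompt".to_vec(), sx)).await.unwrap();
        println!("picked {}", reply.await.unwrap());
    }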
@@ -146,21 +179,71 @@ pub trait LoadBalancer {
     fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
 }

-pub async fn handler<T: LoadBalancer>(
-    State(mut state): State<OverloadHandler<T>>,
-    req: Request,
-) -> Response<Body> {
-    // Get the next backend index
-    let limit = 1024 * 1024;
-    let (parts, body) = req.into_parts();
-    // TODO
-    let bytes = axum::body::to_bytes(body, limit).await.unwrap();
-    let index = state.next(&bytes);
-    let backend = &state.backends[index];
-    state.inflight[index].fetch_add(1, Ordering::Relaxed);
-    state.inqueue[index].fetch_add(1, Ordering::Relaxed);
+#[derive(Debug)]
+pub enum Msg {
+    Next(Vec<u8>, oneshot::Sender<String>),
+    Dequeue(String),
+    Deflight(String),
+    AddBackend(String),
+    RemoveBackend(String),
+}

-    let body: Body = bytes.into();
+type Snd = mpsc::Sender<Msg>;
+type Rcv = mpsc::Receiver<Msg>;
+
+#[derive(Clone)]
+pub struct Communicator {
+    sender: Snd,
+    client: Client,
+}
+
+impl Communicator {
+    pub fn new(sender: Snd) -> Self {
+        let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
+            .build(HttpConnector::new());
+        Self { sender, client }
+    }
+
+    async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
+        self.sender.send(Msg::Dequeue(backend)).await
+    }
+
+    async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
+        self.sender.send(Msg::Deflight(backend)).await
+    }
+
+    async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
+        let (sx, rx) = oneshot::channel();
+        self.sender.send(Msg::Next(key, sx)).await?;
+        let backend = rx.await.unwrap();
+        Ok(backend)
+    }
+}
+
+pub async fn handler(
+    State(state): State<Communicator>,
+    req: Request,
+) -> Result<Response<Body>, StatusCode> {
+    // Get the next backend index
+    let (parts, body) = req.into_parts();
+    let mut response_stream = body.into_data_stream();
+    let event = response_stream.next().await;
+    let key = if let Some(Ok(event)) = &event {
+        event.to_vec()
+    } else {
+        vec![]
+    };
+    let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?;
+    let response_stream = async_stream::stream! {
+        let mut response_stream = Box::pin(response_stream);
+        if let Some(event) = event{
+            yield event;
+        }
+        while let Some(raw_event) = response_stream.next().await {
+            yield raw_event;
+        }
+    };
+    let body = Body::from_stream(response_stream);
     let mut req = Request::from_parts(parts, body);
     let path = req.uri().path();
     let path_query = req
@@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
         .client
         .request(req)
         .await
-        // TODO
-        .unwrap();
-    //.map_err(|_| StatusCode::BAD_GATEWAY)?;
+        .map_err(|_| StatusCode::BAD_GATEWAY)?;
     let response = response.into_response();
     let (parts, body) = response.into_parts();
     let response_stream = body.into_data_stream();
@@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
         if start{
             eprintln!("Not inqueue");

-            state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
+            state.dequeue(backend.to_string()).await.unwrap();
             start = false;
         }
         yield raw_event;
     }
     eprintln!("Not inflight");
-    state.inflight[index].fetch_sub(1, Ordering::Relaxed);
+    state.deflight(backend.to_string()).await.unwrap();
     };

     let body = Body::from_stream(response_stream);

-    Response::from_parts(parts, body)
+    Ok(Response::from_parts(parts, body))
 }
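In the new handler, routing is a message round trip: the first body chunk is read as the routing key, Communicator sends it together with a oneshot sender, and the reply names the backend before the body stream is stitched back together. A minimal, self-contained sketch of that request/reply hop; the PickRequest type and the fixed localhost ports are illustrative, not part of the commit.

    use tokio::sync::{mpsc, oneshot};

    // A request that carries its own reply channel, as Msg::Next does.
    struct PickRequest {
        key: Vec<u8>,
        reply: oneshot::Sender<String>,
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<PickRequest>(100);

        // One task owns the routing decision, standing in for OverloadHandler::run.
        tokio::task::spawn(async move {
            while let Some(req) = rx.recv().await {
                // Placeholder policy: map the key length onto four ports.
                let backend = format!("http://localhost:800{}", req.key.len() % 4);
                let _ = req.reply.send(backend);
            }
        });

        // What Communicator::next does: send the key plus a oneshot, await the answer.
        let (sx, reply) = oneshot::channel();
        tx.send(PickRequest { key: b"first body chunk".to_vec(), reply: sx })
            .await
            .expect("routing task stopped");
        let backend = reply.await.expect("routing task dropped the reply");
        println!("route to {backend}");
    }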
@@ -2,44 +2,41 @@ use axum::{
     routing::Router,
     routing::{get, post},
 };
-use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
+use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};

 #[tokio::main]
 async fn main() {
     // List of backend servers
     let backends = vec![
         "http://localhost:8000".to_string(),
-        "http://localhost:8001".to_string(),
-        "http://localhost:8002".to_string(),
-        "http://localhost:8003".to_string(),
+        // "http://localhost:8001".to_string(),
+        // "http://localhost:8002".to_string(),
+        // "http://localhost:8003".to_string(),
     ];

     // Create a new instance of the RoundRobinRouter

+    println!("Using Content aware");
+    // Create the Axum router
+
+    let (sx, rx) = tokio::sync::mpsc::channel(100);
+    let communicator = Communicator::new(sx);
+    tokio::task::spawn(async move {
     if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
         println!("Using round robin");
         let lb = RoundRobin::new();
-        // Create the Axum router
-        let router = OverloadHandler::new(lb, backends);
-        let app = Router::new()
-            .route("/{*key}", get(handler))
-            .route("/{*key}", post(handler))
-            .with_state(router);
-
-        // run it
-        let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
-            .await
-            .unwrap();
-        println!("listening on {}", listener.local_addr().unwrap());
-        axum::serve(listener, app).await.unwrap();
+        let mut router = OverloadHandler::new(lb, backends, rx);
+        router.run().await;
     } else {
-        println!("Using Content aware");
         let lb = ContentAware::new();
-        // Create the Axum router
-        let router = OverloadHandler::new(lb, backends);
+        let mut router = OverloadHandler::new(lb, backends, rx);
+        router.run().await;
+    };
+    });
     let app = Router::new()
         .route("/{*key}", get(handler))
         .route("/{*key}", post(handler))
-        .with_state(router);
+        .with_state(communicator);

     // run it
     let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
@@ -47,5 +44,4 @@ async fn main() {
         .unwrap();
     println!("listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
-    };
 }