mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Cleaner version.
This commit is contained in:
parent
1932c5b9ed
commit
57fa04adfd
5
Cargo.lock
generated
5
Cargo.lock
generated
@ -2254,6 +2254,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"hyper 1.5.2",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"rand 0.9.0",
|
||||
"slotmap",
|
||||
"tokio",
|
||||
@ -2352,9 +2353,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.22"
|
||||
version = "0.4.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
|
||||
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
|
||||
|
||||
[[package]]
|
||||
name = "loop9"
|
||||
|
@ -12,6 +12,7 @@ futures = "0.3.31"
|
||||
futures-util = "0.3.31"
|
||||
hyper = { version = "1.5.2", features = ["full"] }
|
||||
hyper-util = { version = "0.1.10", features = ["full"] }
|
||||
log = "0.4.25"
|
||||
rand = "0.9.0"
|
||||
slotmap = "1.0.7"
|
||||
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
|
||||
|
@ -5,10 +5,11 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures_util::stream::StreamExt;
|
||||
use hyper::StatusCode;
|
||||
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
|
||||
use rand::{rng, Rng};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
mod trie;
|
||||
|
||||
@ -17,9 +18,8 @@ use crate::trie::Trie;
|
||||
const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
|
||||
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ContentAware {
|
||||
trie: Arc<Mutex<Trie>>,
|
||||
trie: Trie,
|
||||
}
|
||||
|
||||
impl Default for ContentAware {
|
||||
@ -30,14 +30,14 @@ impl Default for ContentAware {
|
||||
|
||||
impl ContentAware {
|
||||
pub fn new() -> Self {
|
||||
let trie = Arc::new(Mutex::new(Trie::new()));
|
||||
let trie = Trie::new();
|
||||
Self { trie }
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadBalancer for ContentAware {
|
||||
fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
|
||||
let mut trie = self.trie.lock().unwrap();
|
||||
let trie = &mut self.trie;
|
||||
let (start, stop) = trie.insert(key);
|
||||
let n = trie.count();
|
||||
eprintln!(
|
||||
@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RoundRobin {
|
||||
current: Arc<AtomicUsize>,
|
||||
current: AtomicUsize,
|
||||
}
|
||||
|
||||
impl Default for RoundRobin {
|
||||
@ -73,7 +72,7 @@ impl Default for RoundRobin {
|
||||
|
||||
impl RoundRobin {
|
||||
pub fn new() -> Self {
|
||||
let current = Arc::new(AtomicUsize::new(0));
|
||||
let current = AtomicUsize::new(0);
|
||||
Self { current }
|
||||
}
|
||||
}
|
||||
@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OverloadHandler<T: LoadBalancer> {
|
||||
client: Client,
|
||||
load_balancer: T,
|
||||
backends: Arc<Vec<String>>,
|
||||
inqueue: Arc<Vec<AtomicUsize>>,
|
||||
inflight: Arc<Vec<AtomicUsize>>,
|
||||
backends: Vec<String>,
|
||||
inqueue: Vec<AtomicUsize>,
|
||||
inflight: Vec<AtomicUsize>,
|
||||
factor: f32,
|
||||
rx: Rcv,
|
||||
}
|
||||
|
||||
impl<T: LoadBalancer> OverloadHandler<T> {
|
||||
pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
|
||||
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||
.build(HttpConnector::new());
|
||||
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
|
||||
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
|
||||
pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
|
||||
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
||||
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
||||
let factor: f32 = std::env::var(FACTOR_KEY)
|
||||
.unwrap_or("1.5".to_string())
|
||||
.parse()
|
||||
.unwrap_or(1.5);
|
||||
let backends = Arc::new(backends);
|
||||
Self {
|
||||
load_balancer,
|
||||
backends,
|
||||
client,
|
||||
factor,
|
||||
inflight,
|
||||
inqueue,
|
||||
rx,
|
||||
}
|
||||
}
|
||||
|
||||
fn next(&mut self, key: &[u8]) -> usize {
|
||||
fn next(&mut self, key: &[u8]) -> String {
|
||||
// Get the backend URL
|
||||
let index = self.load_balancer.next(key, self.backends.len());
|
||||
let n = self.backends.len();
|
||||
@ -138,7 +133,45 @@ impl<T: LoadBalancer> OverloadHandler<T> {
|
||||
inflight = self.inflight[index].load(Ordering::Relaxed);
|
||||
inqueue = self.inflight[index].load(Ordering::Relaxed);
|
||||
}
|
||||
index
|
||||
let backend = &self.backends[index];
|
||||
self.inflight[index].fetch_add(1, Ordering::Relaxed);
|
||||
self.inqueue[index].fetch_add(1, Ordering::Relaxed);
|
||||
backend.to_string()
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) {
|
||||
while let Some(msg) = self.rx.recv().await {
|
||||
eprintln!("Msg {msg:?}");
|
||||
match msg {
|
||||
Msg::Next(key, sx) => {
|
||||
let backend: String = self.next(&key);
|
||||
eprintln!("Sending back backend {backend}");
|
||||
if let Err(err) = sx.send(backend) {
|
||||
eprintln!("Cannot send back result: {err}");
|
||||
}
|
||||
}
|
||||
Msg::Dequeue(backend) => {
|
||||
let index = self.backends.iter().position(|b| b == &backend);
|
||||
if let Some(index) = index {
|
||||
self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
Msg::Deflight(backend) => {
|
||||
let index = self.backends.iter().position(|b| b == &backend);
|
||||
if let Some(index) = index {
|
||||
self.inflight[index].fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
Msg::AddBackend(backend) => {
|
||||
self.backends.push(backend);
|
||||
self.backends.sort();
|
||||
}
|
||||
Msg::RemoveBackend(backend) => {
|
||||
self.backends.retain(|b| *b == backend);
|
||||
self.backends.sort();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -146,21 +179,71 @@ pub trait LoadBalancer {
|
||||
fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
|
||||
}
|
||||
|
||||
pub async fn handler<T: LoadBalancer>(
|
||||
State(mut state): State<OverloadHandler<T>>,
|
||||
req: Request,
|
||||
) -> Response<Body> {
|
||||
// Get the next backend index
|
||||
let limit = 1024 * 1024;
|
||||
let (parts, body) = req.into_parts();
|
||||
// TODO
|
||||
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
|
||||
let index = state.next(&bytes);
|
||||
let backend = &state.backends[index];
|
||||
state.inflight[index].fetch_add(1, Ordering::Relaxed);
|
||||
state.inqueue[index].fetch_add(1, Ordering::Relaxed);
|
||||
#[derive(Debug)]
|
||||
pub enum Msg {
|
||||
Next(Vec<u8>, oneshot::Sender<String>),
|
||||
Dequeue(String),
|
||||
Deflight(String),
|
||||
AddBackend(String),
|
||||
RemoveBackend(String),
|
||||
}
|
||||
|
||||
let body: Body = bytes.into();
|
||||
type Snd = mpsc::Sender<Msg>;
|
||||
type Rcv = mpsc::Receiver<Msg>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Communicator {
|
||||
sender: Snd,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl Communicator {
|
||||
pub fn new(sender: Snd) -> Self {
|
||||
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||
.build(HttpConnector::new());
|
||||
Self { sender, client }
|
||||
}
|
||||
|
||||
async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
|
||||
self.sender.send(Msg::Dequeue(backend)).await
|
||||
}
|
||||
|
||||
async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
|
||||
self.sender.send(Msg::Deflight(backend)).await
|
||||
}
|
||||
|
||||
async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
|
||||
let (sx, rx) = oneshot::channel();
|
||||
self.sender.send(Msg::Next(key, sx)).await?;
|
||||
let backend = rx.await.unwrap();
|
||||
Ok(backend)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handler(
|
||||
State(state): State<Communicator>,
|
||||
req: Request,
|
||||
) -> Result<Response<Body>, StatusCode> {
|
||||
// Get the next backend index
|
||||
let (parts, body) = req.into_parts();
|
||||
let mut response_stream = body.into_data_stream();
|
||||
let event = response_stream.next().await;
|
||||
let key = if let Some(Ok(event)) = &event {
|
||||
event.to_vec()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?;
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
if let Some(event) = event{
|
||||
yield event;
|
||||
}
|
||||
while let Some(raw_event) = response_stream.next().await {
|
||||
yield raw_event;
|
||||
}
|
||||
};
|
||||
let body = Body::from_stream(response_stream);
|
||||
let mut req = Request::from_parts(parts, body);
|
||||
let path = req.uri().path();
|
||||
let path_query = req
|
||||
@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
|
||||
.client
|
||||
.request(req)
|
||||
.await
|
||||
// TODO
|
||||
.unwrap();
|
||||
//.map_err(|_| StatusCode::BAD_GATEWAY)?;
|
||||
.map_err(|_| StatusCode::BAD_GATEWAY)?;
|
||||
let response = response.into_response();
|
||||
let (parts, body) = response.into_parts();
|
||||
let response_stream = body.into_data_stream();
|
||||
@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
|
||||
if start{
|
||||
eprintln!("Not inqueue");
|
||||
|
||||
state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
|
||||
state.dequeue(backend.to_string()).await.unwrap();
|
||||
start = false;
|
||||
}
|
||||
yield raw_event;
|
||||
}
|
||||
eprintln!("Not inflight");
|
||||
state.inflight[index].fetch_sub(1, Ordering::Relaxed);
|
||||
state.deflight(backend.to_string()).await.unwrap();
|
||||
};
|
||||
|
||||
let body = Body::from_stream(response_stream);
|
||||
|
||||
Response::from_parts(parts, body)
|
||||
Ok(Response::from_parts(parts, body))
|
||||
}
|
||||
|
@ -2,50 +2,46 @@ use axum::{
|
||||
routing::Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
|
||||
use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// List of backend servers
|
||||
let backends = vec![
|
||||
"http://localhost:8000".to_string(),
|
||||
"http://localhost:8001".to_string(),
|
||||
"http://localhost:8002".to_string(),
|
||||
"http://localhost:8003".to_string(),
|
||||
// "http://localhost:8001".to_string(),
|
||||
// "http://localhost:8002".to_string(),
|
||||
// "http://localhost:8003".to_string(),
|
||||
];
|
||||
|
||||
// Create a new instance of the RoundRobinRouter
|
||||
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
|
||||
println!("Using round robin");
|
||||
let lb = RoundRobin::new();
|
||||
// Create the Axum router
|
||||
let router = OverloadHandler::new(lb, backends);
|
||||
let app = Router::new()
|
||||
.route("/{*key}", get(handler))
|
||||
.route("/{*key}", post(handler))
|
||||
.with_state(router);
|
||||
|
||||
// run it
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||
.await
|
||||
.unwrap();
|
||||
println!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
} else {
|
||||
println!("Using Content aware");
|
||||
let lb = ContentAware::new();
|
||||
// Create the Axum router
|
||||
let router = OverloadHandler::new(lb, backends);
|
||||
let app = Router::new()
|
||||
.route("/{*key}", get(handler))
|
||||
.route("/{*key}", post(handler))
|
||||
.with_state(router);
|
||||
println!("Using Content aware");
|
||||
// Create the Axum router
|
||||
|
||||
// run it
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||
.await
|
||||
.unwrap();
|
||||
println!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
};
|
||||
let (sx, rx) = tokio::sync::mpsc::channel(100);
|
||||
let communicator = Communicator::new(sx);
|
||||
tokio::task::spawn(async move {
|
||||
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
|
||||
println!("Using round robin");
|
||||
let lb = RoundRobin::new();
|
||||
let mut router = OverloadHandler::new(lb, backends, rx);
|
||||
router.run().await;
|
||||
} else {
|
||||
let lb = ContentAware::new();
|
||||
let mut router = OverloadHandler::new(lb, backends, rx);
|
||||
router.run().await;
|
||||
};
|
||||
});
|
||||
let app = Router::new()
|
||||
.route("/{*key}", get(handler))
|
||||
.route("/{*key}", post(handler))
|
||||
.with_state(communicator);
|
||||
|
||||
// run it
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||
.await
|
||||
.unwrap();
|
||||
println!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user