mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
Fixing add/remove/set backends.
This commit is contained in:
parent
5831ff6e69
commit
18cb4a4221
@ -95,7 +95,7 @@ pub struct OverloadHandler<T: LoadBalancer> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: LoadBalancer> OverloadHandler<T> {
|
impl<T: LoadBalancer> OverloadHandler<T> {
|
||||||
pub async fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
|
pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
|
||||||
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
||||||
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
||||||
let factor: f32 = std::env::var(FACTOR_KEY)
|
let factor: f32 = std::env::var(FACTOR_KEY)
|
||||||
@ -136,7 +136,7 @@ impl<T: LoadBalancer> OverloadHandler<T> {
|
|||||||
index += 1;
|
index += 1;
|
||||||
index %= self.backends.len();
|
index %= self.backends.len();
|
||||||
inflight = self.inflight[index].load(Ordering::Relaxed);
|
inflight = self.inflight[index].load(Ordering::Relaxed);
|
||||||
inqueue = self.inflight[index].load(Ordering::Relaxed);
|
inqueue = self.inqueue[index].load(Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
let backend = &self.backends[index];
|
let backend = &self.backends[index];
|
||||||
self.inflight[index].fetch_add(1, Ordering::Relaxed);
|
self.inflight[index].fetch_add(1, Ordering::Relaxed);
|
||||||
@ -171,15 +171,38 @@ impl<T: LoadBalancer> OverloadHandler<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Msg::AddBackend(backend) => {
|
Msg::AddBackend(backend) => {
|
||||||
self.backends.push(backend);
|
match self.backends.binary_search(&backend) {
|
||||||
self.backends.sort();
|
Ok(pos) => {} // element already in vector @ `pos`
|
||||||
|
Err(pos) => {
|
||||||
|
self.backends.insert(pos, new_elem)
|
||||||
|
self.inflight.insert(pos, AtomicUsize::new(0))
|
||||||
|
self.inqueue.insert(pos, AtomicUsize::new(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Msg::RemoveBackend(backend) => {
|
Msg::RemoveBackend(backend) => {
|
||||||
self.backends.retain(|b| *b == backend);
|
let position = self.backends.iter().position(|b| *b == backend);
|
||||||
self.backends.sort();
|
if let Some(p) = position{
|
||||||
|
self.backends.remove(p);
|
||||||
|
self.inflight.remove(p);
|
||||||
|
self.inqueue.remove(p);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Msg::SetBackends(backends) => {
|
Msg::SetBackends(mut new_backends) => {
|
||||||
self.backends = backends;
|
new_backends.sort();
|
||||||
|
let (new_backends, new_inflight, new_inqueue): (Vec<_>, Vec<_>, Vec<_>) = self.new_backends.iter().enumerate().map(|(ni, nb)| {
|
||||||
|
if let Some(i) = backends.iter().position(|b| b == nb){
|
||||||
|
let inflight = self.inflight[i];
|
||||||
|
let inqueue = self.inqueue[i];
|
||||||
|
(nb, inflight, inqueue)
|
||||||
|
}else{
|
||||||
|
(nb, AtomicUsize::new(0), AtomicUsize::new(0))
|
||||||
|
}
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
self.backends = new_backends;
|
||||||
|
self.inflight = inflight;
|
||||||
|
self.inqueue = inqueue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,15 +24,17 @@ async fn main() {
|
|||||||
|
|
||||||
let (sx, rx) = tokio::sync::mpsc::channel(100);
|
let (sx, rx) = tokio::sync::mpsc::channel(100);
|
||||||
let communicator = Communicator::new(sx);
|
let communicator = Communicator::new(sx);
|
||||||
|
let host = std::env::var("TGI_KVROUTER_HOST").unwrap_or("127.0.0.1");
|
||||||
|
let port : u16= std::env::var("TGI_KVROUTER_PORT").unwrap_or("3000").parse()?;
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
|
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
|
||||||
println!("Using round robin");
|
println!("Using round robin");
|
||||||
let lb = RoundRobin::new();
|
let lb = RoundRobin::new();
|
||||||
let mut router = OverloadHandler::new(lb, backends, rx).await;
|
let mut router = OverloadHandler::new(lb, backends, rx);
|
||||||
router.run().await;
|
router.run().await;
|
||||||
} else {
|
} else {
|
||||||
let lb = ContentAware::new();
|
let lb = ContentAware::new();
|
||||||
let mut router = OverloadHandler::new(lb, backends, rx).await;
|
let mut router = OverloadHandler::new(lb, backends, rx);
|
||||||
router.run().await;
|
router.run().await;
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@ -44,7 +46,7 @@ async fn main() {
|
|||||||
.with_state(communicator);
|
.with_state(communicator);
|
||||||
|
|
||||||
// run it
|
// run it
|
||||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
|
let listener = tokio::net::TcpListener::bind((HOST, PORT)).await.unwrap();
|
||||||
println!("listening on {}", listener.local_addr().unwrap());
|
println!("listening on {}", listener.local_addr().unwrap());
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user