mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
remove useless rwlock
This commit is contained in:
parent
40b2011b3a
commit
5831ff6e69
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2256,6 +2256,7 @@ dependencies = [
|
|||||||
"hyper-util",
|
"hyper-util",
|
||||||
"log",
|
"log",
|
||||||
"rand 0.9.0",
|
"rand 0.9.0",
|
||||||
|
"serde",
|
||||||
"slotmap",
|
"slotmap",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
@ -14,6 +14,6 @@ hyper = { version = "1.5.2", features = ["full"] }
|
|||||||
hyper-util = { version = "0.1.10", features = ["full"] }
|
hyper-util = { version = "0.1.10", features = ["full"] }
|
||||||
log = "0.4.25"
|
log = "0.4.25"
|
||||||
rand = "0.9.0"
|
rand = "0.9.0"
|
||||||
serde = "1"
|
serde = { version = "1", features = ["derive"] }
|
||||||
slotmap = "1.0.7"
|
slotmap = "1.0.7"
|
||||||
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
|
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
|
||||||
|
@ -10,11 +10,8 @@ use hyper::StatusCode;
|
|||||||
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
|
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
|
||||||
use rand::{rng, Rng};
|
use rand::{rng, Rng};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::sync::{
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
atomic::{AtomicUsize, Ordering},
|
use tokio::sync::{mpsc, oneshot};
|
||||||
Arc,
|
|
||||||
};
|
|
||||||
use tokio::sync::{mpsc, oneshot, RwLock};
|
|
||||||
|
|
||||||
mod trie;
|
mod trie;
|
||||||
|
|
||||||
@ -90,7 +87,7 @@ impl LoadBalancer for RoundRobin {
|
|||||||
|
|
||||||
pub struct OverloadHandler<T: LoadBalancer> {
|
pub struct OverloadHandler<T: LoadBalancer> {
|
||||||
load_balancer: T,
|
load_balancer: T,
|
||||||
backends: Arc<RwLock<Vec<String>>>,
|
backends: Vec<String>,
|
||||||
inqueue: Vec<AtomicUsize>,
|
inqueue: Vec<AtomicUsize>,
|
||||||
inflight: Vec<AtomicUsize>,
|
inflight: Vec<AtomicUsize>,
|
||||||
factor: f32,
|
factor: f32,
|
||||||
@ -98,19 +95,9 @@ pub struct OverloadHandler<T: LoadBalancer> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: LoadBalancer> OverloadHandler<T> {
|
impl<T: LoadBalancer> OverloadHandler<T> {
|
||||||
pub async fn new(load_balancer: T, backends: Arc<RwLock<Vec<String>>>, rx: Rcv) -> Self {
|
pub async fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
|
||||||
let inflight = backends
|
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
||||||
.read()
|
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
|
||||||
.await
|
|
||||||
.iter()
|
|
||||||
.map(|_| AtomicUsize::new(0))
|
|
||||||
.collect();
|
|
||||||
let inqueue = backends
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.iter()
|
|
||||||
.map(|_| AtomicUsize::new(0))
|
|
||||||
.collect();
|
|
||||||
let factor: f32 = std::env::var(FACTOR_KEY)
|
let factor: f32 = std::env::var(FACTOR_KEY)
|
||||||
.unwrap_or("1.5".to_string())
|
.unwrap_or("1.5".to_string())
|
||||||
.parse()
|
.parse()
|
||||||
@ -126,13 +113,12 @@ impl<T: LoadBalancer> OverloadHandler<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn next(&mut self, key: &[u8]) -> Option<String> {
|
async fn next(&mut self, key: &[u8]) -> Option<String> {
|
||||||
let backends = self.backends.read().await;
|
if self.backends.is_empty() {
|
||||||
if backends.is_empty() {
|
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
// Get the backend URL
|
// Get the backend URL
|
||||||
let index = self.load_balancer.next(key, backends.len());
|
let index = self.load_balancer.next(key, self.backends.len());
|
||||||
let n = backends.len();
|
let n = self.backends.len();
|
||||||
let mut index = index % n;
|
let mut index = index % n;
|
||||||
|
|
||||||
let mut inflight = self.inflight[index].load(Ordering::Relaxed);
|
let mut inflight = self.inflight[index].load(Ordering::Relaxed);
|
||||||
@ -148,11 +134,11 @@ impl<T: LoadBalancer> OverloadHandler<T> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
index += 1;
|
index += 1;
|
||||||
index %= backends.len();
|
index %= self.backends.len();
|
||||||
inflight = self.inflight[index].load(Ordering::Relaxed);
|
inflight = self.inflight[index].load(Ordering::Relaxed);
|
||||||
inqueue = self.inflight[index].load(Ordering::Relaxed);
|
inqueue = self.inflight[index].load(Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
let backend = &backends[index];
|
let backend = &self.backends[index];
|
||||||
self.inflight[index].fetch_add(1, Ordering::Relaxed);
|
self.inflight[index].fetch_add(1, Ordering::Relaxed);
|
||||||
self.inqueue[index].fetch_add(1, Ordering::Relaxed);
|
self.inqueue[index].fetch_add(1, Ordering::Relaxed);
|
||||||
Some(backend.to_string())
|
Some(backend.to_string())
|
||||||
@ -173,39 +159,27 @@ impl<T: LoadBalancer> OverloadHandler<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Msg::Dequeue(backend) => {
|
Msg::Dequeue(backend) => {
|
||||||
let index = self
|
let index = self.backends.iter().position(|b| b == &backend);
|
||||||
.backends
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.iter()
|
|
||||||
.position(|b| b == &backend);
|
|
||||||
if let Some(index) = index {
|
if let Some(index) = index {
|
||||||
self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
|
self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Msg::Deflight(backend) => {
|
Msg::Deflight(backend) => {
|
||||||
let index = self
|
let index = self.backends.iter().position(|b| b == &backend);
|
||||||
.backends
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.iter()
|
|
||||||
.position(|b| b == &backend);
|
|
||||||
if let Some(index) = index {
|
if let Some(index) = index {
|
||||||
self.inflight[index].fetch_sub(1, Ordering::Relaxed);
|
self.inflight[index].fetch_sub(1, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Msg::AddBackend(backend) => {
|
Msg::AddBackend(backend) => {
|
||||||
let mut backends = self.backends.write().await;
|
self.backends.push(backend);
|
||||||
backends.push(backend);
|
self.backends.sort();
|
||||||
backends.sort();
|
|
||||||
}
|
}
|
||||||
Msg::RemoveBackend(backend) => {
|
Msg::RemoveBackend(backend) => {
|
||||||
let mut backends = self.backends.write().await;
|
self.backends.retain(|b| *b == backend);
|
||||||
backends.retain(|b| *b == backend);
|
self.backends.sort();
|
||||||
backends.sort();
|
|
||||||
}
|
}
|
||||||
Msg::SetBackends(backends) => {
|
Msg::SetBackends(backends) => {
|
||||||
*self.backends.write().await = backends;
|
self.backends = backends;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
routing::Router,
|
routing::Router,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
@ -8,17 +6,16 @@ use hyper::StatusCode;
|
|||||||
use kvrouter::{
|
use kvrouter::{
|
||||||
handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin,
|
handler, set_backends_handler, Communicator, ContentAware, OverloadHandler, RoundRobin,
|
||||||
};
|
};
|
||||||
use tokio::sync::RwLock;
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
// List of backend servers
|
// List of backend servers
|
||||||
let backends = Arc::new(RwLock::new(vec![
|
let backends = vec![
|
||||||
// "http://localhost:8000".to_string(),
|
// "http://localhost:8000".to_string(),
|
||||||
// "http://localhost:8001".to_string(),
|
// "http://localhost:8001".to_string(),
|
||||||
// "http://localhost:8002".to_string(),
|
// "http://localhost:8002".to_string(),
|
||||||
// "http://localhost:8003".to_string(),
|
// "http://localhost:8003".to_string(),
|
||||||
]));
|
];
|
||||||
|
|
||||||
// Create a new instance of the RoundRobinRouter
|
// Create a new instance of the RoundRobinRouter
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user