Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Updating the kvrouter to support roundrobin
for comparison, still with the overloading checks
This commit is contained in:
parent 6a88063cc2
commit 0a495ad118
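This commit factors the kvrouter into a `LoadBalancer` trait with two implementations, `RoundRobin` and the trie-based `ContentAware`, both wrapped by an `OverloadHandler` that keeps the queue/in-flight overload checks. A standalone sketch of that shape follows; it is not the crate's code, and the real types also carry the hyper client, per-backend counters, the overload factor, and the `Trie`.

// Standalone sketch only: the abstraction introduced by this commit.
use std::sync::atomic::{AtomicUsize, Ordering};

pub trait LoadBalancer {
    // Returns a backend index for the given request body; the caller reduces it modulo n_backends.
    fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
}

pub struct RoundRobin {
    current: AtomicUsize,
}

impl LoadBalancer for RoundRobin {
    fn next(&mut self, _key: &[u8], _n_backends: usize) -> usize {
        // Monotonic counter; wrapping to a valid index happens in the caller.
        self.current.fetch_add(1, Ordering::Relaxed)
    }
}

pub struct OverloadHandler<T: LoadBalancer> {
    load_balancer: T,
    backends: Vec<String>,
}

impl<T: LoadBalancer> OverloadHandler<T> {
    fn pick(&mut self, key: &[u8]) -> usize {
        // The real implementation additionally skips backends whose queue
        // exceeds `factor * inflight` (see the hunks below).
        self.load_balancer.next(key, self.backends.len()) % self.backends.len()
    }
}

fn main() {
    let mut h = OverloadHandler {
        load_balancer: RoundRobin { current: AtomicUsize::new(0) },
        backends: vec!["http://localhost:8000".into(), "http://localhost:8001".into()],
    };
    println!("{} {}", h.pick(b"req-a"), h.pick(b"req-b")); // 0 1
}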
@@ -110,10 +110,15 @@ impl Allocator for RadixAllocator {
        let suffix_blocks = suffix_len.div_ceil(self.block_size);
        let prefix_len_uncached = prefill_tokens.as_ref().map(|p| p.len()).unwrap_or_default();
        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
        metrics::counter!("tgi_cache_hit", "allocator" => "radix")
            .increment(prefix_len.try_into().expect("Can convert usize to u64"));
        metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(suffix_len.into());
        metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(
            prefix_len_uncached
                .try_into()
                .expect("Can convert usize to u64"),
        );

        match self.alloc_or_reclaim(suffix_blocks as usize) {
            Some(suffix_blocks) => blocks.extend(suffix_blocks),
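For the allocator hunk above: `div_ceil` rounds the uncached suffix up to whole KV blocks, and the counters now appear to track cached prefix tokens (`tgi_cache_hit`) against the full prefill length (`tgi_cache_total`, fed from `prefix_len_uncached`). A minimal sketch of that arithmetic with made-up numbers:

fn main() {
    // Hypothetical values, not taken from the diff.
    let block_size: u32 = 16;
    let prompt_len: u32 = 1000; // what `prefix_len_uncached` measures: the full prefill length
    let prefix_len: u32 = 600; // tokens already present in the radix cache
    let suffix_len = prompt_len - prefix_len;

    // Same rounding as `suffix_len.div_ceil(self.block_size)` in the hunk.
    let suffix_blocks = suffix_len.div_ceil(block_size); // 400 tokens -> 25 blocks of 16

    // Reading the two counters as hit / total would give a prefix-cache hit ratio.
    let hit_ratio = prefix_len as f64 / prompt_len as f64;
    println!("{suffix_blocks} new blocks needed, hit ratio {hit_ratio:.2}");
}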
@@ -18,37 +18,25 @@ const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;

#[derive(Clone)]
pub struct RoundRobin {
    client: Client,
pub struct ContentAware {
    trie: Arc<Mutex<Trie>>,
    backends: Arc<Vec<String>>,
    inqueue: Arc<Vec<AtomicUsize>>,
    inflight: Arc<Vec<AtomicUsize>>,
    factor: f32,
}

impl RoundRobin {
    pub fn new(backends: Vec<String>) -> Self {
        let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
            .build(HttpConnector::new());
        let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
        let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
        let trie = Arc::new(Mutex::new(Trie::new()));
        let factor: f32 = std::env::var(FACTOR_KEY)
            .unwrap_or("1.5".to_string())
            .parse()
            .unwrap_or(1.5);
        Self {
            inflight,
            inqueue,
            trie,
            client,
            factor,
            backends: Arc::new(backends),
        }
impl Default for ContentAware {
    fn default() -> Self {
        Self::new()
    }
}

    pub fn next(&mut self, key: &[u8]) -> usize {
impl ContentAware {
    pub fn new() -> Self {
        let trie = Arc::new(Mutex::new(Trie::new()));
        Self { trie }
    }
}

impl LoadBalancer for ContentAware {
    fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
        let mut trie = self.trie.lock().unwrap();
        let (start, stop) = trie.insert(key);
        let n = trie.count();
@@ -67,38 +55,108 @@ impl RoundRobin {
        assert!(rescaled_x >= 0.0);
        assert!(rescaled_x <= 1.0);
        println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}");
        let n: usize = (rescaled_x * (self.backends.len() as f32)) as usize;
        let n: usize = (rescaled_x * (n_backends as f32)) as usize;
        n
    }
}

pub async fn handler(State(mut state): State<RoundRobin>, req: Request) -> Response<Body> {
#[derive(Clone)]
pub struct RoundRobin {
    current: Arc<AtomicUsize>,
}

impl Default for RoundRobin {
    fn default() -> Self {
        Self::new()
    }
}

impl RoundRobin {
    pub fn new() -> Self {
        let current = Arc::new(AtomicUsize::new(0));
        Self { current }
    }
}

impl LoadBalancer for RoundRobin {
    fn next(&mut self, _key: &[u8], _n_backends: usize) -> usize {
        self.current.fetch_add(1, Ordering::Relaxed)
    }
}

#[derive(Clone)]
pub struct OverloadHandler<T: LoadBalancer> {
    client: Client,
    load_balancer: T,
    backends: Arc<Vec<String>>,
    inqueue: Arc<Vec<AtomicUsize>>,
    inflight: Arc<Vec<AtomicUsize>>,
    factor: f32,
}

impl<T: LoadBalancer> OverloadHandler<T> {
    pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
        let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
            .build(HttpConnector::new());
        let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
        let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
        let factor: f32 = std::env::var(FACTOR_KEY)
            .unwrap_or("1.5".to_string())
            .parse()
            .unwrap_or(1.5);
        let backends = Arc::new(backends);
        Self {
            load_balancer,
            backends,
            client,
            factor,
            inflight,
            inqueue,
        }
    }

    fn next(&mut self, key: &[u8]) -> usize {
        // Get the backend URL
        let index = self.load_balancer.next(key, self.backends.len());
        let n = self.backends.len();
        let mut index = index % n;

        let mut inflight = self.inflight[index].load(Ordering::Relaxed);
        let mut inqueue = self.inqueue[index].load(Ordering::Relaxed);

        for i in 0..n {
            if (inqueue as f32) <= self.factor * inflight as f32 {
                break;
            }
            if i == 0 {
                eprintln!(
                    "Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead"
                );
            }
            index += 1;
            index %= self.backends.len();
            inflight = self.inflight[index].load(Ordering::Relaxed);
            inqueue = self.inflight[index].load(Ordering::Relaxed);
        }
        index
    }
}

pub trait LoadBalancer {
    fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
}

pub async fn handler<T: LoadBalancer>(
    State(mut state): State<OverloadHandler<T>>,
    req: Request,
) -> Response<Body> {
    // Get the next backend index
    let limit = 2048usize;
    let limit = 1024 * 1024;
    let (parts, body) = req.into_parts();
    // TODO
    let bytes = axum::body::to_bytes(body, limit).await.unwrap();
    let index = state.next(&bytes);
    // Get the backend URL
    let n = state.backends.len();
    let mut index = index % n;
    let backend = &state.backends[index];

    let mut inflight = state.inflight[index].load(Ordering::Relaxed);
    let mut inqueue = state.inqueue[index].load(Ordering::Relaxed);

    for i in 0..n {
        if (inqueue as f32) <= state.factor * inflight as f32 {
            break;
        }
        if i == 0 {
            eprintln!("Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead");
        }
        index += 1;
        index %= state.backends.len();
        inflight = state.inflight[index].load(Ordering::Relaxed);
        inqueue = state.inflight[index].load(Ordering::Relaxed);
    }
    state.inflight[index].fetch_add(1, Ordering::Relaxed);
    state.inqueue[index].fetch_add(1, Ordering::Relaxed);
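The skip-ahead loop above (in `OverloadHandler::next`, and the older copy still inside `handler`) advances to the next backend while its queue exceeds `factor` times its in-flight count. One detail worth flagging: inside the loop the diff reloads `inqueue` from the `inflight` counters, which reads like a copy-paste slip; the standalone sketch below assumes the intended read from the queue counters and uses made-up values.

// Standalone sketch of the overload check (plain integers here; the real code
// uses one AtomicUsize per backend and FACTOR defaults to 1.5).
fn pick(start: usize, inflight: &[usize], inqueue: &[usize], factor: f32) -> usize {
    let n = inflight.len();
    let mut index = start % n;
    for _ in 0..n {
        // A backend is acceptable while its queue is at most factor * inflight.
        if (inqueue[index] as f32) <= factor * inflight[index] as f32 {
            break;
        }
        // Otherwise try the next backend, wrapping around.
        index = (index + 1) % n;
    }
    index
}

fn main() {
    // Backend 0 has 10 queued vs 4 in flight (10 > 1.5 * 4), so it is skipped;
    // backend 1 has 2 queued vs 3 in flight and is accepted.
    let idx = pick(0, &[4, 3, 5], &[10, 2, 1], 1.5);
    println!("{idx}"); // 1
}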
@@ -131,6 +189,7 @@ pub async fn handler(State(mut state): State<RoundRobin>, req: Request) -> Response<Body> {
    while let Some(raw_event) = response_stream.next().await {
        if start {
            eprintln!("Not inqueue");

            state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
            start = false;
        }
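In the streaming hunk above, `inqueue` is released as soon as the first event arrives from the backend's response stream; both counters were bumped when the request was dispatched (end of the previous hunk). The matching `inflight` decrement is not visible in these hunks, so the sketch below marks it as an assumption.

// Standalone sketch of the per-backend counter lifecycle, assuming the
// in-flight counter is released once the whole response has been forwarded
// (that part is outside the hunks shown here).
use std::sync::atomic::{AtomicUsize, Ordering};

struct Counters {
    inflight: AtomicUsize, // requests dispatched and not yet fully answered
    inqueue: AtomicUsize,  // requests dispatched but not yet producing events
}

impl Counters {
    fn dispatch(&self) {
        self.inflight.fetch_add(1, Ordering::Relaxed);
        self.inqueue.fetch_add(1, Ordering::Relaxed);
    }
    fn first_event(&self) {
        // Mirrors `state.inqueue[index].fetch_sub(1, ...)` on the first stream event.
        self.inqueue.fetch_sub(1, Ordering::Relaxed);
    }
    fn finished(&self) {
        // Assumed: released when the response stream ends.
        self.inflight.fetch_sub(1, Ordering::Relaxed);
    }
}

fn main() {
    let c = Counters { inflight: AtomicUsize::new(0), inqueue: AtomicUsize::new(0) };
    c.dispatch();
    c.first_event();
    c.finished();
    assert_eq!(c.inqueue.load(Ordering::Relaxed), 0);
    assert_eq!(c.inflight.load(Ordering::Relaxed), 0);
}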
@@ -2,7 +2,7 @@ use axum::{
    routing::Router,
    routing::{get, post},
};
use kvrouter::{handler, RoundRobin};
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};

#[tokio::main]
async fn main() {
@@ -10,12 +10,16 @@ async fn main() {
    let backends = vec![
        "http://localhost:8000".to_string(),
        "http://localhost:8001".to_string(),
        "http://localhost:8002".to_string(),
        "http://localhost:8003".to_string(),
    ];

    // Create a new instance of the RoundRobinRouter
    let router = RoundRobin::new(backends);

    if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
        println!("Using round robin");
        let lb = RoundRobin::new();
        // Create the Axum router
        let router = OverloadHandler::new(lb, backends);
        let app = Router::new()
            .route("/{*key}", get(handler))
            .route("/{*key}", post(handler))
@@ -27,4 +31,21 @@ async fn main() {
            .unwrap();
        println!("listening on {}", listener.local_addr().unwrap());
        axum::serve(listener, app).await.unwrap();
    } else {
        println!("Using Content aware");
        let lb = ContentAware::new();
        // Create the Axum router
        let router = OverloadHandler::new(lb, backends);
        let app = Router::new()
            .route("/{*key}", get(handler))
            .route("/{*key}", post(handler))
            .with_state(router);

        // run it
        let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
            .await
            .unwrap();
        println!("listening on {}", listener.local_addr().unwrap());
        axum::serve(listener, app).await.unwrap();
    };
}
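Both knobs in this commit are environment variables: `TGI_KVROUTER_LB` selects the policy (`roundrobin` vs the content-aware default) and `TGI_KVROUTER_FACTOR` tunes the overload threshold. A small sketch of those reads, mirroring the diff:

// Sketch: configuration is read entirely from the environment, as in the hunks above.
fn read_config() -> (String, f32) {
    // "roundrobin" selects RoundRobin; anything else falls through to ContentAware.
    let lb = std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string());
    // A backend is skipped while inqueue > factor * inflight; defaults to 1.5.
    let factor: f32 = std::env::var("TGI_KVROUTER_FACTOR")
        .unwrap_or("1.5".to_string())
        .parse()
        .unwrap_or(1.5);
    (lb, factor)
}

fn main() {
    let (lb, factor) = read_config();
    println!("lb={lb:?} factor={factor}");
}

Running with `TGI_KVROUTER_LB=roundrobin` would exercise the round-robin branch of `main`; leaving it unset falls through to the content-aware path.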
@@ -50,7 +50,7 @@ export function get_options() {
    throughput: {
      executor: 'shared-iterations',
      vus: 100,
      iterations: 200,
      iterations: 500,
      maxDuration: '40s',
    },
  },