mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Updating the kvrouter to support roundrobin
for comparison, still with the overloading checks
This commit is contained in:
parent
6a88063cc2
commit
0a495ad118
@ -110,10 +110,15 @@ impl Allocator for RadixAllocator {
|
|||||||
|
|
||||||
let suffix_blocks = suffix_len.div_ceil(self.block_size);
|
let suffix_blocks = suffix_len.div_ceil(self.block_size);
|
||||||
|
|
||||||
|
let prefix_len_uncached = prefill_tokens.as_ref().map(|p| p.len()).unwrap_or_default();
|
||||||
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
||||||
metrics::counter!("tgi_cache_hit", "allocator" => "radix")
|
metrics::counter!("tgi_cache_hit", "allocator" => "radix")
|
||||||
.increment(prefix_len.try_into().expect("Can convert usize to u64"));
|
.increment(prefix_len.try_into().expect("Can convert usize to u64"));
|
||||||
metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(suffix_len.into());
|
metrics::counter!("tgi_cache_total", "allocator" => "radix").increment(
|
||||||
|
prefix_len_uncached
|
||||||
|
.try_into()
|
||||||
|
.expect("Can convert usize to u64"),
|
||||||
|
);
|
||||||
|
|
||||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||||
|
@ -18,37 +18,25 @@ const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
|
|||||||
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
|
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct RoundRobin {
|
pub struct ContentAware {
|
||||||
client: Client,
|
|
||||||
trie: Arc<Mutex<Trie>>,
|
trie: Arc<Mutex<Trie>>,
|
||||||
backends: Arc<Vec<String>>,
|
|
||||||
inqueue: Arc<Vec<AtomicUsize>>,
|
|
||||||
inflight: Arc<Vec<AtomicUsize>>,
|
|
||||||
factor: f32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RoundRobin {
|
impl Default for ContentAware {
|
||||||
pub fn new(backends: Vec<String>) -> Self {
|
fn default() -> Self {
|
||||||
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
Self::new()
|
||||||
.build(HttpConnector::new());
|
}
|
||||||
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
|
}
|
||||||
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
|
|
||||||
|
impl ContentAware {
|
||||||
|
pub fn new() -> Self {
|
||||||
let trie = Arc::new(Mutex::new(Trie::new()));
|
let trie = Arc::new(Mutex::new(Trie::new()));
|
||||||
let factor: f32 = std::env::var(FACTOR_KEY)
|
Self { trie }
|
||||||
.unwrap_or("1.5".to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap_or(1.5);
|
|
||||||
Self {
|
|
||||||
inflight,
|
|
||||||
inqueue,
|
|
||||||
trie,
|
|
||||||
client,
|
|
||||||
factor,
|
|
||||||
backends: Arc::new(backends),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn next(&mut self, key: &[u8]) -> usize {
|
impl LoadBalancer for ContentAware {
|
||||||
|
fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
|
||||||
let mut trie = self.trie.lock().unwrap();
|
let mut trie = self.trie.lock().unwrap();
|
||||||
let (start, stop) = trie.insert(key);
|
let (start, stop) = trie.insert(key);
|
||||||
let n = trie.count();
|
let n = trie.count();
|
||||||
@ -67,38 +55,108 @@ impl RoundRobin {
|
|||||||
assert!(rescaled_x >= 0.0);
|
assert!(rescaled_x >= 0.0);
|
||||||
assert!(rescaled_x <= 1.0);
|
assert!(rescaled_x <= 1.0);
|
||||||
println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}");
|
println!("Start {start:.2} stop {stop:.2}: rescaled {rescaled_x:.2}");
|
||||||
let n: usize = (rescaled_x * (self.backends.len() as f32)) as usize;
|
let n: usize = (rescaled_x * (n_backends as f32)) as usize;
|
||||||
n
|
n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handler(State(mut state): State<RoundRobin>, req: Request) -> Response<Body> {
|
/// Load balancer state that hands out a monotonically increasing ticket number.
///
/// The counter lives behind an `Arc`, so every clone of a `RoundRobin` advances
/// the same shared sequence.
#[derive(Clone, Default)]
pub struct RoundRobin {
    /// Shared ticket counter; starts at zero.
    current: Arc<AtomicUsize>,
}

impl RoundRobin {
    /// Creates a balancer whose shared counter starts at zero.
    pub fn new() -> Self {
        // `Arc<AtomicUsize>::default()` is a fresh counter initialised to 0.
        Self::default()
    }
}
|
||||||
|
|
||||||
|
impl LoadBalancer for RoundRobin {
|
||||||
|
fn next(&mut self, _key: &[u8], _n_backends: usize) -> usize {
|
||||||
|
self.current.fetch_add(1, Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OverloadHandler<T: LoadBalancer> {
|
||||||
|
client: Client,
|
||||||
|
load_balancer: T,
|
||||||
|
backends: Arc<Vec<String>>,
|
||||||
|
inqueue: Arc<Vec<AtomicUsize>>,
|
||||||
|
inflight: Arc<Vec<AtomicUsize>>,
|
||||||
|
factor: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: LoadBalancer> OverloadHandler<T> {
|
||||||
|
pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
|
||||||
|
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||||
|
.build(HttpConnector::new());
|
||||||
|
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
|
||||||
|
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
|
||||||
|
let factor: f32 = std::env::var(FACTOR_KEY)
|
||||||
|
.unwrap_or("1.5".to_string())
|
||||||
|
.parse()
|
||||||
|
.unwrap_or(1.5);
|
||||||
|
let backends = Arc::new(backends);
|
||||||
|
Self {
|
||||||
|
load_balancer,
|
||||||
|
backends,
|
||||||
|
client,
|
||||||
|
factor,
|
||||||
|
inflight,
|
||||||
|
inqueue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next(&mut self, key: &[u8]) -> usize {
|
||||||
|
// Get the backend URL
|
||||||
|
let index = self.load_balancer.next(key, self.backends.len());
|
||||||
|
let n = self.backends.len();
|
||||||
|
let mut index = index % n;
|
||||||
|
|
||||||
|
let mut inflight = self.inflight[index].load(Ordering::Relaxed);
|
||||||
|
let mut inqueue = self.inqueue[index].load(Ordering::Relaxed);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
if (inqueue as f32) <= self.factor * inflight as f32 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
eprintln!(
|
||||||
|
"Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
index += 1;
|
||||||
|
index %= self.backends.len();
|
||||||
|
inflight = self.inflight[index].load(Ordering::Relaxed);
|
||||||
|
inqueue = self.inflight[index].load(Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Strategy for proposing which backend should serve a request.
pub trait LoadBalancer {
    /// Proposes a backend index for the request bytes `key`, given `n_backends`
    /// available backends.
    ///
    /// Implementations are not required to return a value below `n_backends`;
    /// callers should wrap the result into range before indexing.
    fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
}
|
||||||
|
|
||||||
|
pub async fn handler<T: LoadBalancer>(
|
||||||
|
State(mut state): State<OverloadHandler<T>>,
|
||||||
|
req: Request,
|
||||||
|
) -> Response<Body> {
|
||||||
// Get the next backend index
|
// Get the next backend index
|
||||||
let limit = 2048usize;
|
let limit = 1024 * 1024;
|
||||||
let (parts, body) = req.into_parts();
|
let (parts, body) = req.into_parts();
|
||||||
// TODO
|
// TODO
|
||||||
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
|
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
|
||||||
let index = state.next(&bytes);
|
let index = state.next(&bytes);
|
||||||
// Get the backend URL
|
|
||||||
let n = state.backends.len();
|
|
||||||
let mut index = index % n;
|
|
||||||
let backend = &state.backends[index];
|
let backend = &state.backends[index];
|
||||||
|
|
||||||
let mut inflight = state.inflight[index].load(Ordering::Relaxed);
|
|
||||||
let mut inqueue = state.inqueue[index].load(Ordering::Relaxed);
|
|
||||||
|
|
||||||
for i in 0..n {
|
|
||||||
if (inqueue as f32) <= state.factor * inflight as f32 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if i == 0 {
|
|
||||||
eprintln!("Backend overloaded (queue: {inqueue} inflight {inflight}), jumping ahead");
|
|
||||||
}
|
|
||||||
index += 1;
|
|
||||||
index %= state.backends.len();
|
|
||||||
inflight = state.inflight[index].load(Ordering::Relaxed);
|
|
||||||
inqueue = state.inflight[index].load(Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
state.inflight[index].fetch_add(1, Ordering::Relaxed);
|
state.inflight[index].fetch_add(1, Ordering::Relaxed);
|
||||||
state.inqueue[index].fetch_add(1, Ordering::Relaxed);
|
state.inqueue[index].fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
@ -131,6 +189,7 @@ pub async fn handler(State(mut state): State<RoundRobin>, req: Request) -> Respo
|
|||||||
while let Some(raw_event) = response_stream.next().await {
|
while let Some(raw_event) = response_stream.next().await {
|
||||||
if start{
|
if start{
|
||||||
eprintln!("Not inqueue");
|
eprintln!("Not inqueue");
|
||||||
|
|
||||||
state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
|
state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
|
||||||
start = false;
|
start = false;
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ use axum::{
|
|||||||
routing::Router,
|
routing::Router,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
use kvrouter::{handler, RoundRobin};
|
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
@ -10,12 +10,16 @@ async fn main() {
|
|||||||
let backends = vec![
|
let backends = vec![
|
||||||
"http://localhost:8000".to_string(),
|
"http://localhost:8000".to_string(),
|
||||||
"http://localhost:8001".to_string(),
|
"http://localhost:8001".to_string(),
|
||||||
|
"http://localhost:8002".to_string(),
|
||||||
|
"http://localhost:8003".to_string(),
|
||||||
];
|
];
|
||||||
|
|
||||||
// Create a new instance of the RoundRobinRouter
|
// Create a new instance of the RoundRobinRouter
|
||||||
let router = RoundRobin::new(backends);
|
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
|
||||||
|
println!("Using round robin");
|
||||||
|
let lb = RoundRobin::new();
|
||||||
// Create the Axum router
|
// Create the Axum router
|
||||||
|
let router = OverloadHandler::new(lb, backends);
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/{*key}", get(handler))
|
.route("/{*key}", get(handler))
|
||||||
.route("/{*key}", post(handler))
|
.route("/{*key}", post(handler))
|
||||||
@ -27,4 +31,21 @@ async fn main() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
println!("listening on {}", listener.local_addr().unwrap());
|
println!("listening on {}", listener.local_addr().unwrap());
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
} else {
|
||||||
|
println!("Using Content aware");
|
||||||
|
let lb = ContentAware::new();
|
||||||
|
// Create the Axum router
|
||||||
|
let router = OverloadHandler::new(lb, backends);
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/{*key}", get(handler))
|
||||||
|
.route("/{*key}", post(handler))
|
||||||
|
.with_state(router);
|
||||||
|
|
||||||
|
// run it
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
println!("listening on {}", listener.local_addr().unwrap());
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ export function get_options() {
|
|||||||
throughput: {
|
throughput: {
|
||||||
executor: 'shared-iterations',
|
executor: 'shared-iterations',
|
||||||
vus: 100,
|
vus: 100,
|
||||||
iterations: 200,
|
iterations: 500,
|
||||||
maxDuration: '40s',
|
maxDuration: '40s',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user