feat: propagate max_concurrent_requests to queue state entries instead of hardcoded 128 in backends/v3

This commit is contained in:
Venkat Raman 2024-09-26 19:51:10 +02:00
parent 77ddc8309d
commit 45e060e857
4 changed files with 23 additions and 13 deletions

View File

@ -31,6 +31,7 @@ impl BackendV3 {
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
@ -46,6 +47,7 @@ impl BackendV3 {
let block_size = attention.block_size(); let block_size = attention.block_size();
let queue = Queue::new( let queue = Queue::new(
max_concurrent_requests,
requires_padding, requires_padding,
block_size, block_size,
prefix_caching, prefix_caching,

View File

@ -41,6 +41,7 @@ pub async fn connect_backend(
max_batch_total_tokens: Option<u32>, max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_concurrent_requests: usize,
) -> Result<(BackendV3, BackendInfo), V3Error> { ) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function // Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| { let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@ -118,6 +119,7 @@ pub async fn connect_backend(
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
max_concurrent_requests,
shard_info.requires_padding, shard_info.requires_padding,
shard_info.window_size, shard_info.window_size,
shard_info.speculate, shard_info.speculate,

View File

@ -167,6 +167,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
max_concurrent_requests,
) )
.await?; .await?;

View File

@ -44,6 +44,7 @@ pub(crate) struct Queue {
impl Queue { impl Queue {
pub(crate) fn new( pub(crate) fn new(
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool, prefix_caching: bool,
@ -56,6 +57,7 @@ impl Queue {
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task( tokio::spawn(queue_task(
max_concurrent_requests,
requires_padding, requires_padding,
block_size, block_size,
prefix_caching, prefix_caching,
@ -109,6 +111,7 @@ impl Queue {
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task( async fn queue_task(
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool, prefix_caching: bool,
@ -118,6 +121,7 @@ async fn queue_task(
mut receiver: mpsc::UnboundedReceiver<QueueCommand>, mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) { ) {
let mut state = State::new( let mut state = State::new(
max_concurrent_requests,
requires_padding, requires_padding,
block_size, block_size,
prefix_caching, prefix_caching,
@ -178,6 +182,7 @@ struct State {
impl State { impl State {
fn new( fn new(
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool, prefix_caching: bool,
@ -195,7 +200,7 @@ impl State {
}); });
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(max_concurrent_requests),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
block_size, block_size,
@ -567,7 +572,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_append() { async fn test_append() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(128, false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -583,7 +588,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_empty() { async fn test_next_batch_empty() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(128, false, 1, false, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@ -591,7 +596,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_min_size() { async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(128, false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -623,7 +628,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_max_size() { async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, false, None, 0, 16); let mut state = State::new(128, false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -643,7 +648,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_token_budget() { async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, false, None, 0, 2); let mut state = State::new(128, false, 1, false, None, 0, 2);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -676,14 +681,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(128, false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(128, false, 1, false, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -691,7 +696,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(128, false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -724,7 +729,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(128, false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -740,7 +745,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(128, false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -765,7 +770,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, false, None, 2, 16); let queue = Queue::new(128, false, 1, false, None, 2, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -784,7 +789,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, false, None, 0, 16); let queue = Queue::new(128, false, 1, false, None, 0, 16);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);