Heap based router queue (#63) (#88)

Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Karol Damaszke 2024-02-29 10:56:26 +01:00 committed by GitHub
parent 7dbf4bf7a4
commit 8f6564ce0e
3 changed files with 127 additions and 11 deletions

@@ -87,6 +87,7 @@ Environment Variables Added:
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
 | WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
+| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which requests are considered overdue and handled with priority; otherwise, shorter requests are prioritized. | add -e in docker run command |
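The new variable is consumed by the router when the queue is constructed (see the queue changes below) and can be set at container start, e.g. with `-e QUEUE_THRESHOLD_MS=500` in the `docker run` command. A minimal sketch of the intended parsing behavior, assuming only what the diff shows (the helper name here is illustrative, not the router's actual function):

```rust
use std::env;
use std::time::Duration;

// Illustrative sketch: resolve the overdue threshold from QUEUE_THRESHOLD_MS,
// defaulting to 120 ms when the variable is unset or does not parse as an integer.
fn overdue_threshold() -> Duration {
    let default_threshold: u64 = 120;
    let millis = env::var("QUEUE_THRESHOLD_MS")
        .ok()
        .and_then(|value| value.parse().ok())
        .unwrap_or(default_threshold);
    Duration::from_millis(millis)
}

fn main() {
    println!("overdue threshold: {:?}", overdue_threshold());
}
```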

@@ -3,7 +3,10 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
-use std::collections::VecDeque;
+use std::cmp::{Eq, Ord, PartialEq, PartialOrd};
+use std::collections::BinaryHeap;
+use std::env;
+use std::time::Duration;
 use text_generation_client::{Batch, Request};
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
@@ -132,11 +135,104 @@ async fn queue_task(
     }
 }
+
+#[derive(Debug)]
+struct IdentifiableEntry(u64, Entry);
+
+impl Eq for IdentifiableEntry {}
+
+impl PartialEq for IdentifiableEntry {
+    fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
+impl Ord for IdentifiableEntry {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        let ordering = match self
+            .1
+            .request
+            .input_length
+            .cmp(&other.1.request.input_length)
+        {
+            std::cmp::Ordering::Equal => self.0.cmp(&other.0),
+            any => any,
+        };
+
+        // inverse to get min heap
+        return ordering.reverse();
+    }
+}
+
+impl PartialOrd for IdentifiableEntry {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+#[derive(Debug)]
+struct QueueImpl {
+    regular_entries: BinaryHeap<IdentifiableEntry>,
+    overdue_entries: BinaryHeap<IdentifiableEntry>,
+    overdue_threshold: Duration,
+}
+
+impl QueueImpl {
+    fn new(capacity: usize, overdue_threshold: Duration) -> Self {
+        Self {
+            regular_entries: BinaryHeap::with_capacity(capacity),
+            overdue_entries: BinaryHeap::with_capacity(capacity),
+            overdue_threshold,
+        }
+    }
+
+    fn update(&mut self) {
+        if self.regular_entries.is_empty() {
+            return;
+        }
+
+        let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity());
+
+        for entry in self.regular_entries.drain() {
+            if entry.1.queue_time.elapsed() > self.overdue_threshold {
+                self.overdue_entries.push(entry);
+            } else {
+                left.push(entry);
+            }
+        }
+
+        self.regular_entries = left;
+    }
+
+    fn push(&mut self, entry: IdentifiableEntry) {
+        if entry.1.queue_time.elapsed() > self.overdue_threshold {
+            self.overdue_entries.push(entry);
+        } else {
+            self.regular_entries.push(entry);
+        }
+    }
+
+    fn pop(&mut self) -> Option<IdentifiableEntry> {
+        if !self.overdue_entries.is_empty() {
+            self.overdue_entries.pop()
+        } else {
+            self.regular_entries.pop()
+        }
+    }
+
+    fn is_empty(&self) -> bool {
+        self.regular_entries.is_empty() && self.overdue_entries.is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.regular_entries.len() + self.overdue_entries.len()
+    }
+}
+
 /// Queue State
 #[derive(Debug)]
 struct State {
-    /// Queue entries organized in a Vec
-    entries: VecDeque<(u64, Entry)>,
+    /// Queue entries
+    entries: QueueImpl,
     /// Id of the next entry
     next_id: u64,
@@ -166,10 +262,16 @@ impl State {
         max_input_length: u32,
         max_total_tokens: u32,
         block_size: u32,
-        window_size: Option<u32>
+        window_size: Option<u32>,
     ) -> Self {
+        let default_threshold: u64 = 120;
+        let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
+            Ok(val) => val.parse().unwrap_or(default_threshold),
+            Err(_) => default_threshold,
+        };
+
         Self {
-            entries: VecDeque::with_capacity(128),
+            entries: QueueImpl::new(128, Duration::from_millis(threshold)),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
@@ -187,7 +289,7 @@ impl State {
         entry.temp_span = Some(queue_span);
         // Push entry in the queue
-        self.entries.push_back((self.next_id, entry));
+        self.entries.push(IdentifiableEntry(self.next_id, entry));
         self.next_id += 1;
     }
@@ -209,6 +311,8 @@ impl State {
             }
         }
+        self.entries.update();
+
         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());
@@ -221,7 +325,7 @@ impl State {
         let mut decode_tokens: u32 = 0;
         // Pop entries starting from the front of the queue
-        while let Some((id, mut entry)) = self.entries.pop_front() {
+        while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -263,7 +367,7 @@ impl State {
             {
                 // Entry is over budget
                 // Add it back to the front
-                self.entries.push_front((id, entry));
+                self.entries.push(IdentifiableEntry(id, entry));
                 break;
             }
@@ -303,7 +407,7 @@ impl State {
             for r in batch_requests.into_iter().rev() {
                 let id = r.id;
                 let entry = batch_entries.remove(&id).unwrap();
-                self.entries.push_front((id, entry));
+                self.entries.push(IdentifiableEntry(id, entry));
             }
             return None;
@@ -399,7 +503,7 @@ mod tests {
         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0).unwrap();
+        let id = state.entries.pop().unwrap().0;
         assert_eq!(id, 0);
     }
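Taken together, the queue.rs changes replace the FIFO VecDeque with two binary heaps: regular entries form a min-heap on input_length (with the monotonically increasing entry id as a FIFO tie-break), and entries that have waited longer than QUEUE_THRESHOLD_MS are promoted to an overdue heap that is always drained first. A standalone sketch of that ordering, using a simplified stand-in for the router's Entry type (the real one carries the full request, channels, and spans), might look like this:

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Simplified stand-in for the router's Entry: only the field the ordering needs.
#[derive(Debug)]
struct FakeEntry {
    input_length: u32,
}

#[derive(Debug)]
struct IdentifiableEntry(u64, FakeEntry);

impl PartialEq for IdentifiableEntry {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}
impl Eq for IdentifiableEntry {}

impl Ord for IdentifiableEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let ordering = match self.1.input_length.cmp(&other.1.input_length) {
            Ordering::Equal => self.0.cmp(&other.0), // FIFO tie-break on insertion id
            any => any,
        };
        // BinaryHeap is a max-heap; reversing makes it pop the smallest input first.
        ordering.reverse()
    }
}

impl PartialOrd for IdentifiableEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut heap = BinaryHeap::new();
    heap.push(IdentifiableEntry(0, FakeEntry { input_length: 512 }));
    heap.push(IdentifiableEntry(1, FakeEntry { input_length: 64 }));
    heap.push(IdentifiableEntry(2, FakeEntry { input_length: 64 }));

    // Shortest input pops first; equal lengths keep arrival order (id 1 before id 2).
    let order: Vec<u64> = std::iter::from_fn(|| heap.pop().map(|e| e.0)).collect();
    assert_eq!(order, vec![1, 2, 0]);
    println!("pop order: {:?}", order);
}
```

The overdue heap reuses the same ordering, so once requests cross the threshold they are still served shortest-first among themselves, but ahead of all regular entries.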

@@ -2,6 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
+use std::env;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -21,6 +22,7 @@ pub struct Validation {
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
+    skip_tokenizer_in_tgi: bool,
 }
 
 impl Validation {
@@ -59,6 +61,10 @@ impl Validation {
             None
         };
 
+        let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
+            .ok()
+            .map_or(false, |value| value.to_lowercase() == "true");
+
         Self {
             max_best_of,
             sender,
@@ -66,6 +72,7 @@
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            skip_tokenizer_in_tgi,
         }
     }
@@ -130,7 +137,11 @@ impl Validation {
         } else {
             return Err(ValidationError::UnsetMaxNewTokens);
         };
-        let input_length = truncate.unwrap_or(self.max_input_length);
+        let input_length = if self.skip_tokenizer_in_tgi {
+            inputs.chars().filter(|&c| c == ',').count() + 1
+        } else {
+            truncate.unwrap_or(self.max_input_length)
+        };
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
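The validation.rs change means that when SKIP_TOKENIZER_IN_TGI is enabled, the router assumes the input is already a comma-separated list of token ids and estimates its length as the number of commas plus one instead of consulting the tokenizer. A minimal sketch of that estimate (the helper name and signature are illustrative, not the validator's actual API):

```rust
// Illustrative helper: estimate the input length the way the validator does when
// the tokenizer is skipped, i.e. treat the input as comma-separated token ids.
fn estimated_input_length(inputs: &str, skip_tokenizer_in_tgi: bool, fallback: usize) -> usize {
    if skip_tokenizer_in_tgi {
        // "12,4056,731" -> two commas -> three token ids
        inputs.chars().filter(|&c| c == ',').count() + 1
    } else {
        // Without the shortcut, validation falls back to truncate / max_input_length.
        fallback
    }
}

fn main() {
    assert_eq!(estimated_input_length("12,4056,731", true, 1024), 3);
    assert_eq!(estimated_input_length("Hello world", false, 1024), 1024);
    println!("ok");
}
```

This is also the value the heap-based queue orders on, so with the tokenizer skipped the comma count directly drives scheduling priority.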