diff --git a/README.md b/README.md
index 6afb47b2..11b4b2bf 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,7 @@ Environment Variables Added:
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
 | WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
+| QUEUE_THRESHOLD_MS | integer | 120 | Time in milliseconds after which a queued request is considered overdue and handled with priority. Otherwise, shorter requests are prioritized. | add -e in docker run command |
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 9961d27e..2b5f61b1 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -3,7 +3,10 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
-use std::collections::VecDeque;
+use std::cmp::{Eq, Ord, PartialEq, PartialOrd};
+use std::collections::BinaryHeap;
+use std::env;
+use std::time::Duration;
 use text_generation_client::{Batch, Request};
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
@@ -132,11 +135,104 @@ async fn queue_task(
     }
 }
 
+#[derive(Debug)]
+struct IdentifiableEntry(u64, Entry);
+
+impl Eq for IdentifiableEntry {}
+
+impl PartialEq for IdentifiableEntry {
+    fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
+impl Ord for IdentifiableEntry {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        let ordering = match self
+            .1
+            .request
+            .input_length
+            .cmp(&other.1.request.input_length)
+        {
+            std::cmp::Ordering::Equal => self.0.cmp(&other.0),
+            any => any,
+        };
+
+        // inverse to get min heap
+        return ordering.reverse();
+    }
+}
+
+impl PartialOrd for IdentifiableEntry {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+#[derive(Debug)]
+struct QueueImpl {
+    regular_entries: BinaryHeap<IdentifiableEntry>,
+    overdue_entries: BinaryHeap<IdentifiableEntry>,
+    overdue_threshold: Duration,
+}
+
+impl QueueImpl {
+    fn new(capacity: usize, overdue_threshold: Duration) -> Self {
+        Self {
+            regular_entries: BinaryHeap::with_capacity(capacity),
+            overdue_entries: BinaryHeap::with_capacity(capacity),
+            overdue_threshold,
+        }
+    }
+
+    fn update(&mut self) {
+        if self.regular_entries.is_empty() {
+            return;
+        }
+
+        let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity());
+
+        for entry in self.regular_entries.drain() {
+            if entry.1.queue_time.elapsed() > self.overdue_threshold {
+                self.overdue_entries.push(entry);
+            } else {
+                left.push(entry);
+            }
+        }
+
+        self.regular_entries = left;
+    }
+
+    fn push(&mut self, entry: IdentifiableEntry) {
+        if entry.1.queue_time.elapsed() > self.overdue_threshold {
+            self.overdue_entries.push(entry);
+        } else {
+            self.regular_entries.push(entry);
+        }
+    }
+
+    fn pop(&mut self) -> Option<IdentifiableEntry> {
+        if !self.overdue_entries.is_empty() {
+            self.overdue_entries.pop()
+        } else {
+            self.regular_entries.pop()
+        }
+    }
+
+    fn is_empty(&self) -> bool {
+        self.regular_entries.is_empty() && self.overdue_entries.is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.regular_entries.len() + self.overdue_entries.len()
+    }
+}
+
 /// Queue State
 #[derive(Debug)]
 struct State {
-    /// Queue entries organized in a Vec
-    entries: VecDeque<(u64, Entry)>,
+    /// Queue entries
+    entries: QueueImpl,
 
     /// Id of the next entry
     next_id: u64,
@@ -166,10 +262,16 @@ impl State {
         max_input_length: u32,
         max_total_tokens: u32,
         block_size: u32,
-        window_size: Option<u32>
+        window_size: Option<u32>,
     ) -> Self {
+        let default_threshold: u64 = 120;
+        let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
+            Ok(val) => val.parse().unwrap_or(default_threshold),
+            Err(_) => default_threshold,
+        };
+
         Self {
-            entries: VecDeque::with_capacity(128),
+            entries: QueueImpl::new(128, Duration::from_millis(threshold)),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
@@ -187,7 +289,7 @@ impl State {
         entry.temp_span = Some(queue_span);
 
         // Push entry in the queue
-        self.entries.push_back((self.next_id, entry));
+        self.entries.push(IdentifiableEntry(self.next_id, entry));
        self.next_id += 1;
    }

@@ -209,6 +311,8 @@
             }
         }
 
+        self.entries.update();
+
         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());
@@ -221,7 +325,7 @@
         let mut decode_tokens: u32 = 0;
 
         // Pop entries starting from the front of the queue
-        while let Some((id, mut entry)) = self.entries.pop_front() {
+        while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -263,7 +367,7 @@
             {
                 // Entry is over budget
                 // Add it back to the front
-                self.entries.push_front((id, entry));
+                self.entries.push(IdentifiableEntry(id, entry));
                 break;
             }
@@ -303,7 +407,7 @@
             for r in batch_requests.into_iter().rev() {
                 let id = r.id;
                 let entry = batch_entries.remove(&id).unwrap();
-                self.entries.push_front((id, entry));
+                self.entries.push(IdentifiableEntry(id, entry));
             }
 
             return None;
@@ -399,7 +503,7 @@ mod tests {
         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);
 
-        let (id, _) = state.entries.remove(0).unwrap();
+        let id = state.entries.pop().unwrap().0;
         assert_eq!(id, 0);
     }

diff --git a/router/src/validation.rs b/router/src/validation.rs
index 1b47fc97..17e72b87 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -2,6 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
+use std::env;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -21,6 +22,7 @@ pub struct Validation {
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
+    skip_tokenizer_in_tgi: bool,
 }
 
 impl Validation {
@@ -59,6 +61,10 @@ impl Validation {
             None
         };
 
+        let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
+            .ok()
+            .map_or(false, |value| value.to_lowercase() == "true");
+
         Self {
             max_best_of,
             sender,
@@ -66,6 +72,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            skip_tokenizer_in_tgi,
         }
     }

@@ -130,7 +137,11 @@ impl Validation {
         } else {
             return Err(ValidationError::UnsetMaxNewTokens);
         };
-        let input_length = truncate.unwrap_or(self.max_input_length);
+        let input_length = if self.skip_tokenizer_in_tgi {
+            inputs.chars().filter(|&c| c == ',').count() + 1
+        } else {
+            truncate.unwrap_or(self.max_input_length)
+        };
 
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
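Note for reviewers: below is a minimal, self-contained sketch of the scheduling policy this patch introduces. `Entry` here is a simplified stand-in for the router's real entry type (only `input_length` and `queue_time` are modeled, using `std::time::Instant` rather than tokio's), and the two heaps are used directly instead of through `QueueImpl`; the comparator and the overdue promotion mirror the diff above.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

// Simplified stand-in for the router's Entry / ValidGenerateRequest pair.
#[derive(Debug)]
struct Entry {
    input_length: u32,   // stand-in for request.input_length
    queue_time: Instant, // when the request entered the queue
}

#[derive(Debug)]
struct IdentifiableEntry(u64, Entry);

impl PartialEq for IdentifiableEntry {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl Eq for IdentifiableEntry {}

impl Ord for IdentifiableEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Shorter inputs first, ties broken by arrival id (FIFO); reversed
        // because BinaryHeap is a max-heap and we want a min-heap.
        match self.1.input_length.cmp(&other.1.input_length) {
            Ordering::Equal => self.0.cmp(&other.0),
            any => any,
        }
        .reverse()
    }
}

impl PartialOrd for IdentifiableEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    // QUEUE_THRESHOLD_MS default from the README table above.
    let threshold = Duration::from_millis(120);
    let now = Instant::now();

    let mut regular: BinaryHeap<IdentifiableEntry> = BinaryHeap::new();
    let mut overdue: BinaryHeap<IdentifiableEntry> = BinaryHeap::new();

    // One long request that has already waited past the threshold,
    // plus two short, fresh ones.
    for (id, input_length, waited_ms) in [(0u64, 900u32, 200u64), (1, 10, 5), (2, 10, 1)] {
        let entry = IdentifiableEntry(
            id,
            Entry {
                input_length,
                queue_time: now - Duration::from_millis(waited_ms),
            },
        );
        // Same routing as QueueImpl::push: overdue requests jump the
        // length-based ordering entirely.
        if entry.1.queue_time.elapsed() > threshold {
            overdue.push(entry);
        } else {
            regular.push(entry);
        }
    }

    // Same order as QueueImpl::pop: drain overdue entries before regular ones.
    // Prints id=0 (overdue) first, then id=1 and id=2 (shortest input first,
    // FIFO on ties).
    while let Some(e) = overdue.pop().or_else(|| regular.pop()) {
        println!("id={} input_length={}", e.0, e.1.input_length);
    }
}
```

Reversing the comparator turns `BinaryHeap`'s max-heap into the min-heap the patch relies on, and the separate overdue heap keeps long prompts from starving behind a steady stream of shorter ones: once a request has waited longer than QUEUE_THRESHOLD_MS, it is served ahead of the length-ordered queue. The validation.rs change is independent: when SKIP_TOKENIZER_IN_TGI is set, inputs are already comma-separated token ids, so the router estimates `input_length` by counting commas instead of consulting tokenizer output.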