mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
This commit is contained in:
parent
7dbf4bf7a4
commit
8f6564ce0e
@ -87,6 +87,7 @@ Environment Variables Added:
|
||||
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
|
||||
| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
|
||||
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
|
||||
| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which the request are considered overdue and handled with priority. Shorter requests are prioritized otherwise. | add -e in docker run command |
|
||||
</div>
|
||||
|
||||
|
||||
|
@ -3,7 +3,10 @@ use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use std::cmp::{Eq, Ord, PartialEq, PartialOrd};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
@ -132,11 +135,104 @@ async fn queue_task(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IdentifiableEntry(u64, Entry);
|
||||
|
||||
impl Eq for IdentifiableEntry {}
|
||||
|
||||
impl PartialEq for IdentifiableEntry {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for IdentifiableEntry {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
let ordering = match self
|
||||
.1
|
||||
.request
|
||||
.input_length
|
||||
.cmp(&other.1.request.input_length)
|
||||
{
|
||||
std::cmp::Ordering::Equal => self.0.cmp(&other.0),
|
||||
any => any,
|
||||
};
|
||||
|
||||
// inverse to get min heap
|
||||
return ordering.reverse();
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for IdentifiableEntry {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct QueueImpl {
|
||||
regular_entries: BinaryHeap<IdentifiableEntry>,
|
||||
overdue_entries: BinaryHeap<IdentifiableEntry>,
|
||||
overdue_threshold: Duration,
|
||||
}
|
||||
|
||||
impl QueueImpl {
|
||||
fn new(capacity: usize, overdue_threshold: Duration) -> Self {
|
||||
Self {
|
||||
regular_entries: BinaryHeap::with_capacity(capacity),
|
||||
overdue_entries: BinaryHeap::with_capacity(capacity),
|
||||
overdue_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self) {
|
||||
if self.regular_entries.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity());
|
||||
|
||||
for entry in self.regular_entries.drain() {
|
||||
if entry.1.queue_time.elapsed() > self.overdue_threshold {
|
||||
self.overdue_entries.push(entry);
|
||||
} else {
|
||||
left.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
self.regular_entries = left;
|
||||
}
|
||||
|
||||
fn push(&mut self, entry: IdentifiableEntry) {
|
||||
if entry.1.queue_time.elapsed() > self.overdue_threshold {
|
||||
self.overdue_entries.push(entry);
|
||||
} else {
|
||||
self.regular_entries.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
fn pop(&mut self) -> Option<IdentifiableEntry> {
|
||||
if !self.overdue_entries.is_empty() {
|
||||
self.overdue_entries.pop()
|
||||
} else {
|
||||
self.regular_entries.pop()
|
||||
}
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.regular_entries.is_empty() && self.overdue_entries.is_empty()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.regular_entries.len() + self.overdue_entries.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Queue State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// Queue entries organized in a Vec
|
||||
entries: VecDeque<(u64, Entry)>,
|
||||
/// Queue entries
|
||||
entries: QueueImpl,
|
||||
|
||||
/// Id of the next entry
|
||||
next_id: u64,
|
||||
@ -166,10 +262,16 @@ impl State {
|
||||
max_input_length: u32,
|
||||
max_total_tokens: u32,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>
|
||||
window_size: Option<u32>,
|
||||
) -> Self {
|
||||
let default_threshold: u64 = 120;
|
||||
let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
|
||||
Ok(val) => val.parse().unwrap_or(default_threshold),
|
||||
Err(_) => default_threshold,
|
||||
};
|
||||
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
entries: QueueImpl::new(128, Duration::from_millis(threshold)),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
@ -187,7 +289,7 @@ impl State {
|
||||
entry.temp_span = Some(queue_span);
|
||||
|
||||
// Push entry in the queue
|
||||
self.entries.push_back((self.next_id, entry));
|
||||
self.entries.push(IdentifiableEntry(self.next_id, entry));
|
||||
self.next_id += 1;
|
||||
}
|
||||
|
||||
@ -209,6 +311,8 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
self.entries.update();
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
@ -221,7 +325,7 @@ impl State {
|
||||
let mut decode_tokens: u32 = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
@ -263,7 +367,7 @@ impl State {
|
||||
{
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
self.entries.push_front((id, entry));
|
||||
self.entries.push(IdentifiableEntry(id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
@ -303,7 +407,7 @@ impl State {
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push_front((id, entry));
|
||||
self.entries.push(IdentifiableEntry(id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
@ -399,7 +503,7 @@ mod tests {
|
||||
|
||||
assert_eq!(state.next_id, 1);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
let id = state.entries.pop().unwrap().0;
|
||||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::{thread_rng, Rng};
|
||||
use std::env;
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
@ -21,6 +22,7 @@ pub struct Validation {
|
||||
max_total_tokens: usize,
|
||||
/// Channel to communicate with the background tokenization task
|
||||
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
|
||||
skip_tokenizer_in_tgi: bool,
|
||||
}
|
||||
|
||||
impl Validation {
|
||||
@ -59,6 +61,10 @@ impl Validation {
|
||||
None
|
||||
};
|
||||
|
||||
let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
|
||||
.ok()
|
||||
.map_or(false, |value| value.to_lowercase() == "true");
|
||||
|
||||
Self {
|
||||
max_best_of,
|
||||
sender,
|
||||
@ -66,6 +72,7 @@ impl Validation {
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
skip_tokenizer_in_tgi,
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,7 +137,11 @@ impl Validation {
|
||||
} else {
|
||||
return Err(ValidationError::UnsetMaxNewTokens);
|
||||
};
|
||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||
let input_length = if self.skip_tokenizer_in_tgi {
|
||||
inputs.chars().filter(|&c| c == ',').count() + 1
|
||||
} else {
|
||||
truncate.unwrap_or(self.max_input_length)
|
||||
};
|
||||
|
||||
// Validate MaxNewTokens
|
||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||
|
Loading…
Reference in New Issue
Block a user