Mirror of https://github.com/huggingface/text-generation-inference.git
Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
parent 7dbf4bf7a4
commit 8f6564ce0e
@@ -87,6 +87,7 @@ Environment Variables Added:
 | SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
 | TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
 | WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
+| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold, in milliseconds, beyond which requests are considered overdue and handled with priority. Shorter requests are prioritized otherwise. | add -e in docker run command |
 </div>
@@ -3,7 +3,10 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
-use std::collections::VecDeque;
+use std::cmp::{Eq, Ord, PartialEq, PartialOrd};
+use std::collections::BinaryHeap;
+use std::env;
+use std::time::Duration;
 use text_generation_client::{Batch, Request};
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
@@ -132,11 +135,104 @@ async fn queue_task(
     }
 }
 
+#[derive(Debug)]
+struct IdentifiableEntry(u64, Entry);
+
+impl Eq for IdentifiableEntry {}
+
+impl PartialEq for IdentifiableEntry {
+    fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
+impl Ord for IdentifiableEntry {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        let ordering = match self
+            .1
+            .request
+            .input_length
+            .cmp(&other.1.request.input_length)
+        {
+            std::cmp::Ordering::Equal => self.0.cmp(&other.0),
+            any => any,
+        };
+
+        // inverse to get min heap
+        return ordering.reverse();
+    }
+}
+
+impl PartialOrd for IdentifiableEntry {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+#[derive(Debug)]
+struct QueueImpl {
+    regular_entries: BinaryHeap<IdentifiableEntry>,
+    overdue_entries: BinaryHeap<IdentifiableEntry>,
+    overdue_threshold: Duration,
+}
+
+impl QueueImpl {
+    fn new(capacity: usize, overdue_threshold: Duration) -> Self {
+        Self {
+            regular_entries: BinaryHeap::with_capacity(capacity),
+            overdue_entries: BinaryHeap::with_capacity(capacity),
+            overdue_threshold,
+        }
+    }
+
+    fn update(&mut self) {
+        if self.regular_entries.is_empty() {
+            return;
+        }
+
+        let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity());
+
+        for entry in self.regular_entries.drain() {
+            if entry.1.queue_time.elapsed() > self.overdue_threshold {
+                self.overdue_entries.push(entry);
+            } else {
+                left.push(entry);
+            }
+        }
+
+        self.regular_entries = left;
+    }
+
+    fn push(&mut self, entry: IdentifiableEntry) {
+        if entry.1.queue_time.elapsed() > self.overdue_threshold {
+            self.overdue_entries.push(entry);
+        } else {
+            self.regular_entries.push(entry);
+        }
+    }
+
+    fn pop(&mut self) -> Option<IdentifiableEntry> {
+        if !self.overdue_entries.is_empty() {
+            self.overdue_entries.pop()
+        } else {
+            self.regular_entries.pop()
+        }
+    }
+
+    fn is_empty(&self) -> bool {
+        self.regular_entries.is_empty() && self.overdue_entries.is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.regular_entries.len() + self.overdue_entries.len()
+    }
+}
+
 /// Queue State
 #[derive(Debug)]
 struct State {
-    /// Queue entries organized in a Vec
-    entries: VecDeque<(u64, Entry)>,
+    /// Queue entries
+    entries: QueueImpl,
 
     /// Id of the next entry
     next_id: u64,
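The reversed `Ord` above turns Rust's max-heap `BinaryHeap` into a min-heap keyed on `input_length`, with the monotonically increasing id as a FIFO tie-break, so shorter (and, among equals, older) requests are popped first. A minimal standalone sketch of the same ordering trick, using a hypothetical `Item(id, input_length)` pair instead of the router's `Entry`:

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Toy stand-in for IdentifiableEntry: (id, input_length).
#[derive(Debug, Eq, PartialEq)]
struct Item(u64, u32);

impl Ord for Item {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare by input length, break ties by id, then reverse so the
        // max-heap pops the shortest / oldest item first.
        self.1.cmp(&other.1).then(self.0.cmp(&other.0)).reverse()
    }
}

impl PartialOrd for Item {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut heap = BinaryHeap::new();
    heap.push(Item(0, 512));
    heap.push(Item(1, 16));
    heap.push(Item(2, 16));
    // Prints Item(1, 16), Item(2, 16), Item(0, 512).
    while let Some(item) = heap.pop() {
        println!("{:?}", item);
    }
}
```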
@@ -166,10 +262,16 @@ impl State {
         max_input_length: u32,
         max_total_tokens: u32,
         block_size: u32,
-        window_size: Option<u32>
+        window_size: Option<u32>,
     ) -> Self {
+        let default_threshold: u64 = 120;
+        let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
+            Ok(val) => val.parse().unwrap_or(default_threshold),
+            Err(_) => default_threshold,
+        };
+
         Self {
-            entries: VecDeque::with_capacity(128),
+            entries: QueueImpl::new(128, Duration::from_millis(threshold)),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
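As the README table above notes, `QUEUE_THRESHOLD_MS` is passed with `-e` on the `docker run` command line; the constructor change reads it once at queue creation and falls back to 120 ms when the variable is unset or not a valid integer. A small self-contained sketch of that fallback (the `queue_threshold` helper is illustrative only, not a function in the codebase):

```rust
use std::env;
use std::time::Duration;

// Illustrative helper mirroring the fallback in State::new above:
// an unset or unparsable QUEUE_THRESHOLD_MS resolves to the 120 ms default.
fn queue_threshold() -> Duration {
    let default_threshold: u64 = 120;
    let ms = env::var("QUEUE_THRESHOLD_MS")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(default_threshold);
    Duration::from_millis(ms)
}

fn main() {
    // Prints 250ms when run with QUEUE_THRESHOLD_MS=250, and 120ms when the
    // variable is unset or set to something like "abc".
    println!("{:?}", queue_threshold());
}
```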
@@ -187,7 +289,7 @@ impl State {
         entry.temp_span = Some(queue_span);
 
         // Push entry in the queue
-        self.entries.push_back((self.next_id, entry));
+        self.entries.push(IdentifiableEntry(self.next_id, entry));
         self.next_id += 1;
     }
 
@@ -209,6 +311,8 @@ impl State {
             }
         }
 
+        self.entries.update();
+
         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());
@@ -221,7 +325,7 @@ impl State {
         let mut decode_tokens: u32 = 0;
 
         // Pop entries starting from the front of the queue
-        while let Some((id, mut entry)) = self.entries.pop_front() {
+        while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
@@ -263,7 +367,7 @@ impl State {
             {
                 // Entry is over budget
                 // Add it back to the front
-                self.entries.push_front((id, entry));
+                self.entries.push(IdentifiableEntry(id, entry));
                 break;
             }
 
@@ -303,7 +407,7 @@ impl State {
             for r in batch_requests.into_iter().rev() {
                 let id = r.id;
                 let entry = batch_entries.remove(&id).unwrap();
-                self.entries.push_front((id, entry));
+                self.entries.push(IdentifiableEntry(id, entry));
             }
 
             return None;
@@ -399,7 +503,7 @@ mod tests {
 
         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0).unwrap();
+        let id = state.entries.pop().unwrap().0;
        assert_eq!(id, 0);
     }
 
@@ -2,6 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
+use std::env;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -21,6 +22,7 @@ pub struct Validation {
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
+    skip_tokenizer_in_tgi: bool,
 }
 
 impl Validation {
@@ -59,6 +61,10 @@ impl Validation {
             None
         };
 
+        let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
+            .ok()
+            .map_or(false, |value| value.to_lowercase() == "true");
+
         Self {
             max_best_of,
             sender,
@@ -66,6 +72,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            skip_tokenizer_in_tgi,
         }
     }
 
@@ -130,7 +137,11 @@ impl Validation {
            } else {
                return Err(ValidationError::UnsetMaxNewTokens);
            };
-            let input_length = truncate.unwrap_or(self.max_input_length);
+            let input_length = if self.skip_tokenizer_in_tgi {
+                inputs.chars().filter(|&c| c == ',').count() + 1
+            } else {
+                truncate.unwrap_or(self.max_input_length)
+            };
 
            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
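When `SKIP_TOKENIZER_IN_TGI` is enabled, the validation path above no longer asks the tokenizer for a length; it counts commas in the raw input and adds one, which appears to assume the client submits pre-tokenized input as a comma-separated list of token ids (an assumption inferred from this diff, not stated in it). A standalone sketch of that length computation:

```rust
// Illustrative only: mirrors the comma-counting branch above. Assumes the
// request body is a comma-separated list of token ids when the tokenizer
// is skipped; that interpretation is not confirmed by this diff alone.
fn skipped_tokenizer_input_length(inputs: &str) -> usize {
    inputs.chars().filter(|&c| c == ',').count() + 1
}

fn main() {
    assert_eq!(skipped_tokenizer_input_length("15496,11,995"), 3);
    assert_eq!(skipped_tokenizer_input_length("15496"), 1);
    // Edge case shared with the original logic: an empty string still
    // yields a length of 1.
    assert_eq!(skipped_tokenizer_input_length(""), 1);
}
```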