diff --git a/Cargo.lock b/Cargo.lock
index 8fa7b7266..b1f7279a4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -743,18 +743,6 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"

-[[package]]
-name = "flume"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
-dependencies = [
- "futures-core",
- "futures-sink",
- "nanorand",
- "spin 0.9.8",
-]
-
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -900,10 +888,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
 dependencies = [
  "cfg-if",
- "js-sys",
  "libc",
  "wasi",
- "wasm-bindgen",
 ]

 [[package]]
@@ -1508,15 +1494,6 @@ dependencies = [
  "tracing",
 ]

-[[package]]
-name = "nanorand"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
-dependencies = [
- "getrandom",
-]
-
 [[package]]
 name = "native-tls"
 version = "0.2.11"
@@ -2313,7 +2290,7 @@ dependencies = [
  "cc",
  "libc",
  "once_cell",
- "spin 0.5.2",
+ "spin",
  "untrusted",
  "web-sys",
  "winapi",
@@ -2678,15 +2655,6 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"

-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-dependencies = [
- "lock_api",
-]
-
 [[package]]
 name = "spm_precompiled"
 version = "0.1.4"
@@ -2808,7 +2776,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "average",
  "clap",
@@ -2829,7 +2797,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2845,7 +2813,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2861,13 +2829,12 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
  "async-stream",
  "axum",
  "axum-tracing-opentelemetry",
  "clap",
- "flume",
  "futures",
  "hf-hub 0.3.1",
  "init-tracing-opentelemetry",
@@ -2885,6 +2852,7 @@ dependencies = [
  "thiserror",
  "tokenizers",
  "tokio",
+ "tokio-stream",
  "tower-http",
  "tracing",
  "tracing-opentelemetry",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 87b5a8d39..55af635a2 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -20,7 +20,6 @@ axum = { version = "0.6.20", features = ["json"] }
 axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
-flume = "0.11.0"
 futures = "0.3.28"
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@@ -34,6 +33,7 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { version = "0.14.0", features = ["http"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"
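The dependency change above is the heart of the PR: flume (which pulled in nanorand, getrandom's js-sys/wasm-bindgen bindings, and a second spin version) is dropped in favor of the mpsc channels tokio already provides, plus tokio-stream for the Stream adapter. As a point of reference, here is a minimal, self-contained sketch of the replacement pattern — the names are illustrative, not the router's actual types:

```rust
// flume::unbounded() + rx.into_stream() becomes mpsc::unbounded_channel()
// + UnboundedReceiverStream::new(rx).
use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel::<u32>();

    // Sending on an unbounded channel never blocks; it only fails once the
    // receiving half has been dropped.
    for i in 0..3 {
        tx.send(i).unwrap();
    }
    drop(tx); // close the channel so the stream below terminates

    // tokio's receiver is not itself a Stream; the tokio-stream wrapper
    // bridges it to the futures Stream world used for streaming responses.
    let mut stream = UnboundedReceiverStream::new(rx);
    while let Some(n) = stream.next().await {
        println!("received {n}");
    }
}
```

One behavioral difference matters for the rest of the diff: flume channels are multi-producer, multi-consumer, while tokio's mpsc has a single consumer. That is what forces the fan-out rework in validation.rs further down.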
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index f8f5df957..341e70fd5 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -107,15 +107,14 @@ impl Client {
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
-        let mut truncate = 0;

         // Create requests
         while n_tokens < max_prefill_tokens {
-            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: truncate,
+                truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 787ccfcf1..cc34c466b 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -2,22 +2,21 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
-use flume::r#async::RecvStream;
-use flume::SendTimeoutError;
 use futures::future::try_join_all;
-use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Inference struct
@@ -90,7 +89,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
-            RecvStream<Result<InferStreamResponse, InferError>>,
+            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
     > {
@@ -113,7 +112,7 @@ impl Infer {
         })?;

         // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = flume::unbounded();
+        let (response_tx, response_rx) = mpsc::unbounded_channel();

         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +129,7 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, response_rx.into_stream()))
+        Ok((permit, UnboundedReceiverStream::new(response_rx)))
     }

     /// Add a new request to the queue and return a InferResponse
@@ -493,10 +492,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
-            if let SendTimeoutError::Timeout(_) = *err {
-                tracing::error!("Entry response channel timed out.")
-            }
-
+            tracing::error!("Entry response channel error.");
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);
@@ -510,9 +506,10 @@ fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
-    if entry.response_tx.is_disconnected() {
+    if entry.response_tx.is_closed() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }
@@ -520,10 +517,9 @@ fn send_responses(
     // Prefill
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Prefill(prefill_tokens)),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }

     // Create last Token
@@ -558,22 +554,18 @@ fn send_responses(
             // Generation has ended
             stopped = true;
             // Send message
-            entry.response_tx.send_timeout(
-                Ok(InferStreamResponse::End {
-                    token,
-                    top_tokens,
-                    generated_text,
-                    queued: entry.queue_time,
-                    start: entry.batch_time.unwrap(),
-                }),
-                Duration::from_millis(10),
-            )?;
+            entry.response_tx.send(Ok(InferStreamResponse::End {
+                token,
+                top_tokens,
+                generated_text,
+                queued: entry.queue_time,
+                start: entry.batch_time.unwrap(),
+            }))?;
         } else {
             // Send message
-            entry.response_tx.send_timeout(
-                Ok(InferStreamResponse::Intermediate { token, top_tokens }),
-                Duration::from_millis(10),
-            )?;
+            entry
+                .response_tx
+                .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
         }
     Ok(stopped)
 }
@@ -591,7 +583,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send_timeout(Err(err), Duration::from_millis(10))
+            .send(Err(err))
             .unwrap_or(());
     });
 }
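Throughout infer.rs the 10 ms `send_timeout` calls collapse into plain `send`: an unbounded tokio send never waits, so there is nothing for a timeout to guard against, and the only possible failure is a dropped receiver. Likewise `is_closed()` replaces flume's `is_disconnected()`. A hedged sketch of the semantics the new code relies on (a `&'static str` message standing in for `Result<InferStreamResponse, InferError>`):

```rust
use tokio::sync::mpsc::{error::SendError, UnboundedSender};

type Msg = &'static str;

// Returns Ok(true) when generation should stop for this entry, mirroring
// the shape of `send_responses` above (simplified, illustrative only).
fn forward(tx: &UnboundedSender<Msg>, msg: Msg) -> Result<bool, SendError<Msg>> {
    // `is_closed` is true once the receiving half (the client's response
    // stream) has been dropped.
    if tx.is_closed() {
        return Ok(true);
    }
    // `send` on an unbounded channel is synchronous and cannot time out,
    // which is why the old 10 ms `send_timeout` disappears entirely.
    tx.send(msg)?;
    Ok(false)
}

fn main() {
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    assert!(!forward(&tx, "hello").unwrap());
    assert_eq!(rx.try_recv().ok(), Some("hello"));

    drop(rx);
    // With the receiver gone, the early return fires, just like the
    // dropped-client check at the top of `send_responses`.
    assert!(forward(&tx, "ignored").unwrap());
}
```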
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 1ab9eb11e..bbb8db0e4 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -5,7 +5,7 @@ use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
-use tokio::sync::oneshot;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

@@ -15,7 +15,7 @@ pub(crate) struct Entry {
     /// Request
     pub request: ValidGenerateRequest,
     /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
     /// Temporary span used as a guard when logging inference, wait times...
@@ -30,13 +30,13 @@
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
     /// Channel to communicate with the background queue task
-    queue_sender: flume::Sender<QueueCommand>,
+    queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }

 impl Queue {
     pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
-        let (queue_sender, queue_receiver) = flume::unbounded();
+        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

         // Launch background queue task
         tokio::spawn(queue_task(
@@ -91,11 +91,11 @@ async fn queue_task(
     requires_padding: bool,
     block_size: u32,
     window_size: Option<u32>,
-    receiver: flume::Receiver<QueueCommand>,
+    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
     let mut state = State::new(requires_padding, block_size, window_size);

-    while let Ok(cmd) = receiver.recv_async().await {
+    while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
@@ -195,7 +195,7 @@ impl State {
         while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
-            if entry.response_tx.is_disconnected() {
+            if entry.response_tx.is_closed() {
                 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
                 continue;
             }
@@ -321,9 +321,9 @@ mod tests {

     fn default_entry() -> (
         Entry,
-        flume::Receiver<Result<InferStreamResponse, InferError>>,
+        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
     ) {
-        let (response_tx, receiver_tx) = flume::unbounded();
+        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

         let entry = Entry {
             request: ValidGenerateRequest {
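The queue task shows the other recurring mechanical change: tokio's `UnboundedReceiver::recv` takes `&mut self` and yields `Option` rather than `Result`, so receivers move into their tasks as `mut` and the loops become `while let Some(..)` instead of flume's `while let Ok(..) = receiver.recv_async().await`. A small stand-alone sketch of that loop shape (`QueueCommand` here is a stand-in enum, not the router's):

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum QueueCommand {
    Append(u64),
}

// `recv` returns None once every sender has been dropped, which doubles as
// the task's shutdown signal — no explicit stop command is needed.
async fn queue_task(mut receiver: mpsc::UnboundedReceiver<QueueCommand>) {
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(id) => println!("append entry {id}"),
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let task = tokio::spawn(queue_task(rx));
    tx.send(QueueCommand::Append(42)).unwrap();
    drop(tx); // closing the channel lets the background task exit
    task.await.unwrap();
}
```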
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 37465272a..7a84640dc 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
+use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};

@@ -19,7 +20,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
-    sender: Option<flume::Sender<TokenizerRequest>>,
+    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }

 impl Validation {
@@ -34,19 +35,25 @@
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
-            // Create channel
-            let (validation_sender, validation_receiver) = flume::unbounded();
+            // Create round robin channel
+            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut senders = Vec::with_capacity(workers);

             // Create workers
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
-                let receiver_clone = validation_receiver.clone();
+                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
+                senders.push(tokenizer_sender);

                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, receiver_clone)
+                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                 });
             }
+
+            // Create tokenization round robin task
+            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+
             Some(validation_sender)
         } else {
             None
@@ -118,12 +125,10 @@
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
+        } else if let Some(truncate) = truncate {
+            self.max_total_tokens.saturating_sub(truncate) as u32
         } else {
-            if let Some(truncate) = truncate {
-                self.max_total_tokens.saturating_sub(truncate) as u32
-            } else {
-                return Err(ValidationError::UnsetMaxNewTokens);
-            }
+            return Err(ValidationError::UnsetMaxNewTokens);
         };
         let input_length = truncate.unwrap_or(self.max_input_length);
@@ -309,10 +314,25 @@ impl Validation {
     }
 }

+/// Round robin tokenization task
+async fn round_robin_task(
+    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
+    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+) {
+    loop {
+        for sender in &senders {
+            match receiver.recv().await {
+                None => return,
+                Some(request) => sender.send(request).unwrap(),
+            };
+        }
+    }
+}
+
 /// Start tokenization workers
-fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
+fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
     // Loop over requests
-    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
+    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(inputs, truncate, &tokenizer))
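The validation change is the one place where tokio's single-consumer mpsc forces a new design rather than a mechanical swap: flume receivers could be cloned, so every tokenizer worker shared one queue, whereas now a round-robin task fans requests out over one dedicated channel per worker, and the blocking workers switch from `recv()` to `blocking_recv()`. A self-contained sketch of that shape, with `String` jobs standing in for `TokenizerRequest`:

```rust
use tokio::sync::mpsc;

// One async dispatcher owns the single inbound receiver and rotates jobs
// across per-worker channels, like `round_robin_task` above.
async fn round_robin(
    mut rx: mpsc::UnboundedReceiver<String>,
    senders: Vec<mpsc::UnboundedSender<String>>,
) {
    loop {
        for sender in &senders {
            match rx.recv().await {
                None => return, // all producers gone: shut down
                Some(job) => sender.send(job).unwrap(),
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let mut senders = Vec::new();
    for id in 0..2usize {
        let (worker_tx, mut worker_rx) = mpsc::unbounded_channel::<String>();
        senders.push(worker_tx);
        // `blocking_recv` is the synchronous counterpart of `recv`, suitable
        // inside `spawn_blocking` threads, mirroring the tokenizer workers.
        tokio::task::spawn_blocking(move || {
            while let Some(job) = worker_rx.blocking_recv() {
                println!("worker {id} handles {job:?}");
            }
        });
    }
    tokio::spawn(round_robin(rx, senders));

    tx.send("tokenize this".to_string()).unwrap();
    drop(tx);
    // Give the dispatcher and workers a moment to drain before exiting.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
```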