From bd3a9d8e856cb7e2122f1a09d2fb0f44b7649dad Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 23 Jun 2023 14:58:28 +0200
Subject: [PATCH] fix(router): add timeout on flume sends (#488)

---
 router/src/infer.rs | 45 ++++++++++++++++++++++++++++++---------------
 router/src/queue.rs |  2 +-
 2 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 00fa2818..f738f986 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -3,7 +3,7 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
-use flume::SendError;
+use flume::SendTimeoutError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
@@ -11,6 +11,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
+use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -472,6 +473,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
+            if let SendTimeoutError::Timeout(_) = *err {
+                tracing::error!("Entry response channel timed out.")
+            }
+
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);
@@ -485,14 +490,20 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u
 fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+    // Return directly if the channel is disconnected
+    if entry.response_tx.is_disconnected() {
+        return Ok(true);
+    }
+
     let mut stopped = false;

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::Prefill(prefill_tokens)),
+            Duration::from_millis(10),
+        )?;
     }

     // Create last Token
@@ -507,17 +518,21 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx.send(Ok(InferStreamResponse::End {
-            token,
-            generated_text,
-            queued: entry.queue_time,
-            start: entry.batch_time.unwrap(),
-        }))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::End {
+                token,
+                generated_text,
+                queued: entry.queue_time,
+                start: entry.batch_time.unwrap(),
+            }),
+            Duration::from_millis(10),
+        )?;
     } else {
         // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Token(token)))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::Token(token)),
+            Duration::from_millis(10),
+        )?;
     }
     Ok(stopped)
 }
@@ -535,7 +550,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send(Err(err))
+            .send_timeout(Err(err), Duration::from_millis(10))
             .unwrap_or(());
     });
 }
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 0586083d..6d1d4d12 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -95,7 +95,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComm
                 span,
             } => span.in_scope(|| {
                 let next_batch = state.next_batch(min_size, token_budget);
-                response_sender.send(next_batch).unwrap_or(());
+                response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
         }
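
For context, a minimal sketch (not from this patch) of the two flume APIs the diff relies on, `send_timeout` and `is_disconnected`. The zero-capacity channel and the printed messages are assumptions chosen only so the timeout path is easy to trigger; the router's real channels and the 10 ms budget above may behave differently in practice.

// Standalone sketch illustrating flume's send_timeout / is_disconnected.
// Assumes a rendezvous (zero-capacity) channel purely for demonstration.
use std::time::Duration;

fn main() {
    // A send on a zero-capacity channel can only complete while a receiver waits.
    let (tx, rx) = flume::bounded::<u32>(0);

    // Nobody is receiving, so instead of blocking forever the send gives up
    // after 10 ms and hands back the value it failed to deliver.
    match tx.send_timeout(1, Duration::from_millis(10)) {
        Ok(()) => println!("sent"),
        Err(flume::SendTimeoutError::Timeout(v)) => println!("timed out sending {v}"),
        Err(flume::SendTimeoutError::Disconnected(v)) => println!("receiver dropped, kept {v}"),
    }

    // Dropping the receiver disconnects the channel; the patch checks this
    // up front with is_disconnected() to skip sending entirely.
    drop(rx);
    assert!(tx.is_disconnected());
}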