diff --git a/router/src/health.rs b/router/src/health.rs
index 788ff6fe..ab290fc1 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,7 +1,8 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
-use tokio::sync::watch;
 
 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -10,11 +11,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: watch::Receiver<bool>,
+    generation_health: Arc<AtomicBool>,
 }
 
 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
         Self {
             client,
             generation_health,
@@ -22,7 +23,7 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if *self.generation_health.borrow() {
+        if self.generation_health.load(Ordering::SeqCst) {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -58,7 +59,10 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            self.client.prefill(batch).await.is_ok()
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
         }
     }
 }
diff --git a/router/src/infer.rs b/router/src/infer.rs
index eaf4204f..cc34c466 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -4,13 +4,16 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -47,7 +50,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: watch::Sender<bool>,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -261,7 +264,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: watch::Sender<bool>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -368,7 +371,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -377,7 +380,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -391,7 +394,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -405,33 +408,29 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
-    tracing::info!("Decode");
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
-            tracing::info!("filter send");
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
-            tracing::info!("filter batch");
             let next_batch = filter_batch(client, next_batch, entries).await;
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            tracing::info!("Decode ended");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
@@ -466,13 +465,11 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("filter batch call");
         client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index fad731d3..f254afd8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -19,10 +19,11 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -583,8 +584,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
-    let health_ext = Health::new(client.clone(), generation_health_receiver);
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -595,7 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health_sender,
+        generation_health,
     );
 
     // Duration buckets
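
The patch replaces the `tokio::sync::watch` channel with a single shared `Arc<AtomicBool>`: the batching task (in `prefill`/`decode`) and the health-check prefill store the outcome of the latest generation step, and `Health::check` loads it to choose between a cheap shard gRPC ping and a full prefill probe. The snippet below is not part of the patch; it is a minimal standalone sketch of that sharing pattern, assuming a tokio dependency, with illustrative names only.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

#[tokio::main]
async fn main() {
    // Shared flag, created once at startup as in `run` in server.rs.
    let generation_health = Arc::new(AtomicBool::new(false));

    // Writer side: a background task records whether the last generation step
    // succeeded, standing in for `batching_task` / `prefill` / `decode`.
    let writer = generation_health.clone();
    let task = tokio::spawn(async move {
        let step_succeeded = true; // placeholder for `client.prefill(..).await.is_ok()`
        writer.store(step_succeeded, Ordering::SeqCst);
    });
    task.await.unwrap();

    // Reader side: the health check loads the flag, as `Health::check` does.
    if generation_health.load(Ordering::SeqCst) {
        println!("generation healthy: only ping the shards");
    } else {
        println!("generation unhealthy: run a real prefill to probe");
    }
}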