From f40f02fc257b6296362f84f9a40f0b82de0219d8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 20 Oct 2023 17:39:28 +0200 Subject: [PATCH] remove atomicbool --- router/src/health.rs | 14 +++++--------- router/src/infer.rs | 23 ++++++++++------------- router/src/server.rs | 9 ++++----- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/router/src/health.rs b/router/src/health.rs index ab290fc1..788ff6fe 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -1,8 +1,7 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; use text_generation_client::{ Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; +use tokio::sync::watch; // Note: Request ids and batch ids cannot collide. const LIVENESS_ID: u64 = u64::MAX; @@ -11,11 +10,11 @@ const BATCH_ID: u64 = u64::MAX; #[derive(Clone, Debug)] pub(crate) struct Health { client: ShardedClient, - generation_health: Arc, + generation_health: watch::Receiver, } impl Health { - pub(crate) fn new(client: ShardedClient, generation_health: Arc) -> Self { + pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver) -> Self { Self { client, generation_health, @@ -23,7 +22,7 @@ impl Health { } pub(crate) async fn check(&mut self) -> bool { - if self.generation_health.load(Ordering::SeqCst) { + if *self.generation_health.borrow() { // Generation is healthy, we only check that the shards are answering gRPC calls self.client.health().await.is_ok() } else { @@ -59,10 +58,7 @@ impl Health { max_tokens: 2, }; // Skips the queue - let value = self.client.prefill(batch).await.is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value + self.client.prefill(batch).await.is_ok() } } } diff --git a/router/src/infer.rs b/router/src/infer.rs index cc34c466..f7e62a25 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -4,16 +4,13 @@ use crate::{Entry, Queue, Token}; use crate::{GenerateRequest, PrefillToken}; use futures::future::try_join_all; use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; +use std::sync::Arc; use text_generation_client::{ Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, }; use thiserror::Error; use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; @@ -50,7 +47,7 @@ impl Infer { max_concurrent_requests: usize, requires_padding: bool, window_size: Option, - generation_health: Arc, + generation_health: watch::Sender, ) -> Self { // Infer shared state let queue = Queue::new(requires_padding, 16, window_size); @@ -264,7 +261,7 @@ async fn batching_task( max_waiting_tokens: usize, queue: Queue, shared: Arc, - generation_health: Arc, + generation_health: watch::Sender, ) { // Infinite loop loop { @@ -371,7 +368,7 @@ async fn prefill( client: &mut ShardedClient, batch: Batch, entries: &mut IntMap, - generation_health: &Arc, + generation_health: &watch::Sender, ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; @@ -380,7 +377,7 @@ async fn prefill( match client.prefill(batch).await { Ok((generations, next_batch)) => { // Update health - generation_health.store(true, Ordering::SeqCst); + let _ = generation_health.send(true); // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); @@ -394,7 +391,7 @@ async fn prefill( // If we have an error, we discard the whole batch Err(err) => { // Update health - generation_health.store(false, Ordering::SeqCst); + let _ = generation_health.send(false); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); @@ -408,7 +405,7 @@ async fn decode( client: &mut ShardedClient, batches: Vec, entries: &mut IntMap, - generation_health: &Arc, + generation_health: &watch::Sender, ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); @@ -417,7 +414,7 @@ async fn decode( match client.decode(batches).await { Ok((generations, next_batch)) => { // Update health - generation_health.store(true, Ordering::SeqCst); + let _ = generation_health.send(true); // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); @@ -430,7 +427,7 @@ async fn decode( } // If we have an error, we discard the whole batch Err(err) => { - generation_health.store(false, Ordering::SeqCst); + let _ = generation_health.send(false); for id in batch_ids { let _ = client.clear_cache(Some(id)).await; } diff --git a/router/src/server.rs b/router/src/server.rs index f254afd8..fad731d3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -19,11 +19,10 @@ use futures::Stream; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use std::convert::Infallible; use std::net::SocketAddr; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; use text_generation_client::{ShardInfo, ShardedClient}; use tokenizers::Tokenizer; use tokio::signal; +use tokio::sync::watch; use tokio::time::Instant; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; @@ -584,8 +583,8 @@ pub async fn run( max_input_length, max_total_tokens, ); - let generation_health = Arc::new(AtomicBool::new(false)); - let health_ext = Health::new(client.clone(), generation_health.clone()); + let (generation_health_sender, generation_health_receiver) = watch::channel(false); + let health_ext = Health::new(client.clone(), generation_health_receiver); let infer = Infer::new( client, validation, @@ -596,7 +595,7 @@ pub async fn run( max_concurrent_requests, shard_info.requires_padding, shard_info.window_size, - generation_health, + generation_health_sender, ); // Duration buckets