remove atomicbool

OlivierDehaene 2023-10-20 17:39:28 +02:00
parent 2c3772528d
commit f40f02fc25
3 changed files with 19 additions and 27 deletions
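
The change replaces the shared Arc<AtomicBool> that tracked generation health with a tokio::sync::watch channel: the batching task keeps the watch::Sender<bool> and publishes the flag after every prefill/decode call, while the Health checker reads the latest value through a watch::Receiver<bool>. As a minimal, self-contained sketch of that pattern (illustrative only, not code from this commit; the health_tx/health_rx names are invented):

// Illustrative sketch of the watch-channel pattern adopted below;
// not part of this commit. `health_tx`/`health_rx` are invented names.
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // `false` mirrors the initial value the commit passes to watch::channel.
    let (health_tx, mut health_rx) = watch::channel(false);

    // Writer side (the batching task): publish the latest health state.
    // `send` only errors once every receiver is dropped, hence `let _ =`.
    let _ = health_tx.send(true);

    // Reader side (the health check): `borrow` yields the most recent
    // value, with no manual memory-ordering bookkeeping.
    assert!(*health_rx.borrow());

    // Unlike an AtomicBool, a receiver can also await the next update.
    let _ = health_tx.send(false);
    health_rx.changed().await.unwrap();
    assert!(!*health_rx.borrow());
}

The watch channel keeps the last-write-wins semantics of the atomic flag but drops the explicit Ordering::SeqCst bookkeeping, and changed() lets a reader await health transitions instead of polling.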

router/src/health.rs

@@ -1,8 +1,7 @@
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
+use tokio::sync::watch;
 
 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -11,11 +10,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: Arc<AtomicBool>,
+    generation_health: watch::Receiver<bool>,
 }
 
 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
         Self {
             client,
             generation_health,
@@ -23,7 +22,7 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if self.generation_health.load(Ordering::SeqCst) {
+        if *self.generation_health.borrow() {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -59,10 +58,7 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            let value = self.client.prefill(batch).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            self.client.prefill(batch).await.is_ok()
         }
     }
 }

router/src/infer.rs

@@ -4,16 +4,13 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::{
-    atomic::{AtomicBool, Ordering},
-    Arc,
-};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -50,7 +47,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: Arc<AtomicBool>,
+        generation_health: watch::Sender<bool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -264,7 +261,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: Arc<AtomicBool>,
+    generation_health: watch::Sender<bool>,
 ) {
     // Infinite loop
     loop {
@@ -371,7 +368,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
+    generation_health: &watch::Sender<bool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -380,7 +377,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            generation_health.store(true, Ordering::SeqCst);
+            let _ = generation_health.send(true);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -394,7 +391,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            generation_health.store(false, Ordering::SeqCst);
+            let _ = generation_health.send(false);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -408,7 +405,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
+    generation_health: &watch::Sender<bool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -417,7 +414,7 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            generation_health.store(true, Ordering::SeqCst);
+            let _ = generation_health.send(true);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -430,7 +427,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            generation_health.store(false, Ordering::SeqCst);
+            let _ = generation_health.send(false);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }

router/src/server.rs

@@ -19,11 +19,10 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use std::sync::atomic::AtomicBool;
-use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
+use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -584,8 +583,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
+    let health_ext = Health::new(client.clone(), generation_health_receiver);
     let infer = Infer::new(
         client,
         validation,
@@ -596,7 +595,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health,
+        generation_health_sender,
     );
 
     // Duration buckets