remove atomicbool

OlivierDehaene 2023-10-20 17:39:28 +02:00
parent 2c3772528d
commit f40f02fc25
3 changed files with 19 additions and 27 deletions

router/src/health.rs

@@ -1,8 +1,7 @@
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
+use tokio::sync::watch;
 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -11,11 +10,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: Arc<AtomicBool>,
+    generation_health: watch::Receiver<bool>,
 }
 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
         Self {
             client,
             generation_health,
@@ -23,7 +22,7 @@ impl Health {
     }
     pub(crate) async fn check(&mut self) -> bool {
-        if self.generation_health.load(Ordering::SeqCst) {
+        if *self.generation_health.borrow() {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -59,10 +58,7 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            let value = self.client.prefill(batch).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            self.client.prefill(batch).await.is_ok()
         }
     }
 }

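The health check above now reads the flag through watch::Receiver::borrow instead of AtomicBool::load. Below is a minimal sketch of that receiver side, assuming a tokio dependency with the "sync" feature; the struct and field names are illustrative, not the router's actual types.

use tokio::sync::watch;

// Illustrative stand-in for the Health struct above: it only holds the
// receiving end of the channel and never writes to it.
struct GenerationHealth {
    receiver: watch::Receiver<bool>,
}

impl GenerationHealth {
    fn is_healthy(&self) -> bool {
        // borrow() yields a guard over the most recently sent value;
        // dereferencing it reads the bool without explicit atomics.
        *self.receiver.borrow()
    }
}

fn main() {
    let (sender, receiver) = watch::channel(false);
    let health = GenerationHealth { receiver };
    assert!(!health.is_healthy());

    // The sender side (the batching loop in this commit) flips the flag.
    sender.send(true).unwrap();
    assert!(health.is_healthy());
}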
router/src/infer.rs

@@ -4,16 +4,13 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::{
-    atomic::{AtomicBool, Ordering},
-    Arc,
-};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -50,7 +47,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: Arc<AtomicBool>,
+        generation_health: watch::Sender<bool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -264,7 +261,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: Arc<AtomicBool>,
+    generation_health: watch::Sender<bool>,
 ) {
     // Infinite loop
     loop {
@@ -371,7 +368,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
+    generation_health: &watch::Sender<bool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -380,7 +377,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            generation_health.store(true, Ordering::SeqCst);
+            let _ = generation_health.send(true);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
@@ -394,7 +391,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            generation_health.store(false, Ordering::SeqCst);
+            let _ = generation_health.send(false);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -408,7 +405,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
+    generation_health: &watch::Sender<bool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -417,7 +414,7 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            generation_health.store(true, Ordering::SeqCst);
+            let _ = generation_health.send(true);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
@@ -430,7 +427,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            generation_health.store(false, Ordering::SeqCst);
+            let _ = generation_health.send(false);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }

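On this side the worker now publishes health by sending on the channel rather than storing into an atomic; send only fails once every receiver has been dropped, which is why the result is discarded with let _. A rough sketch of that pattern, assuming tokio with the "sync", "rt" and "macros" features; the helper function name is made up for illustration.

use tokio::sync::watch;

// Illustrative: report the outcome of one prefill/decode round, mirroring the
// let _ = generation_health.send(..) calls in the diff above.
fn report_outcome(generation_health: &watch::Sender<bool>, succeeded: bool) {
    // send() errors only if all receivers are gone; that is harmless here.
    let _ = generation_health.send(succeeded);
}

#[tokio::main]
async fn main() {
    let (sender, mut receiver) = watch::channel(false);

    report_outcome(&sender, true);

    // Unlike an AtomicBool, a watch channel can also notify waiters:
    // changed() resolves because a new value was sent since the last read.
    receiver.changed().await.unwrap();
    assert!(*receiver.borrow());
}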
router/src/server.rs

@@ -19,11 +19,10 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use std::sync::atomic::AtomicBool;
-use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
+use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -584,8 +583,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
+    let health_ext = Health::new(client.clone(), generation_health_receiver);
     let infer = Infer::new(
         client,
         validation,
@@ -596,7 +595,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health,
+        generation_health_sender,
     );
     // Duration buckets
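
The wiring in run above reduces to creating one channel and handing each half to the component that needs it. A condensed sketch under the same assumptions, with placeholder structs standing in for Health and Infer:

use tokio::sync::watch;

// Placeholder types; the real router passes the receiver into Health::new and
// the sender into Infer::new, as in the diff above.
struct HealthCheck {
    generation_health: watch::Receiver<bool>,
}

struct Batcher {
    generation_health: watch::Sender<bool>,
}

fn main() {
    // Generation starts out unhealthy until the first successful batch.
    let (generation_health_sender, generation_health_receiver) = watch::channel(false);

    let health = HealthCheck {
        generation_health: generation_health_receiver,
    };
    let batcher = Batcher {
        generation_health: generation_health_sender,
    };

    assert!(!*health.generation_health.borrow());
    let _ = batcher.generation_health.send(true);
    assert!(*health.generation_health.borrow());
}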