back to atomicbool

OlivierDehaene 2023-10-23 12:47:16 +02:00
parent 3f1cc9bad7
commit b2a5dd64c1
3 changed files with 27 additions and 25 deletions
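
This commit swaps the tokio::sync::watch channel that carried generation health back to a shared Arc<AtomicBool>: the batching side stores the outcome of its last call, and the health check reads it with a plain load. A minimal, self-contained sketch of that pattern (names and the simulated failure are illustrative, not taken from the diff):

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

fn main() {
    // Shared flag, `false` until a generation call has succeeded.
    let generation_health = Arc::new(AtomicBool::new(false));

    // Producer (stands in for the batching task): record each call's outcome.
    let producer_flag = generation_health.clone();
    let producer = thread::spawn(move || {
        for i in 0..3u32 {
            let ok = i != 1; // simulate one failing call
            producer_flag.store(ok, Ordering::SeqCst);
            thread::sleep(Duration::from_millis(10));
        }
    });

    // Consumer (stands in for the health check): a plain load, no channel.
    for _ in 0..3 {
        let healthy = generation_health.load(Ordering::SeqCst);
        println!("generation healthy: {healthy}");
        thread::sleep(Duration::from_millis(10));
    }

    producer.join().unwrap();
}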

View File

@@ -1,7 +1,8 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
-use tokio::sync::watch;

 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -10,11 +11,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: watch::Receiver<bool>,
+    generation_health: Arc<AtomicBool>,
 }

 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
         Self {
             client,
             generation_health,
@@ -22,7 +23,7 @@ impl Health {
     }

     pub(crate) async fn check(&mut self) -> bool {
-        if *self.generation_health.borrow() {
+        if self.generation_health.load(Ordering::SeqCst) {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -58,7 +59,10 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            self.client.prefill(batch).await.is_ok()
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
         }
     }
 }
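
For reference, the shape of the check path after this change: when the flag is set, only the shards' gRPC health endpoint is probed; otherwise a one-request prefill is attempted and its outcome is written back to the flag. A simplified synchronous sketch with the client calls stubbed out (the fake_* functions are stand-ins, not part of the codebase):

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Stand-ins for the gRPC calls made through ShardedClient (illustrative only).
fn fake_shard_health() -> bool {
    true
}
fn fake_tiny_prefill() -> bool {
    true
}

fn check(generation_health: &Arc<AtomicBool>) -> bool {
    if generation_health.load(Ordering::SeqCst) {
        // Generation already proved healthy: a cheap liveness probe is enough.
        fake_shard_health()
    } else {
        // Run a minimal prefill and record the result for the next check.
        let value = fake_tiny_prefill();
        generation_health.store(value, Ordering::SeqCst);
        value
    }
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    assert!(check(&flag)); // slow path: prefill, then the flag is set
    assert!(flag.load(Ordering::SeqCst));
    assert!(check(&flag)); // fast path: liveness probe only
}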

View File

@@ -4,13 +4,16 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -47,7 +50,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: watch::Sender<bool>,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -261,7 +264,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: watch::Sender<bool>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -368,7 +371,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -377,7 +380,7 @@ async fn prefill(

     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
@@ -391,7 +394,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -405,33 +408,29 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
-    tracing::info!("Decode");

     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
-            tracing::info!("filter send");
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            tracing::info!("filter batch");
             let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            tracing::info!("Decode ended");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
@@ -466,13 +465,11 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("filter batch call");
         client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
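
The watch imports and send calls removed above are what the commit title refers to: the previous revision pushed health updates through a tokio::sync::watch channel instead of the atomic flag. A minimal sketch of that earlier pattern for comparison (assumes the tokio crate with its sync feature; variable names mirror the diff):

use tokio::sync::watch;

fn main() {
    // Single-value channel: the sender publishes the latest health flag.
    let (generation_health_sender, generation_health_receiver) = watch::channel(false);

    // Batching side: publish the outcome of the last inference call.
    let _ = generation_health_sender.send(true);

    // Health-check side: read whatever value was published last.
    let healthy = *generation_health_receiver.borrow();
    assert!(healthy);
}

The atomic flag drops the sender/receiver pair and the let _ = needed to discard SendError, at the cost of change notification, which the polling-style check above does not use.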

View File

@@ -19,10 +19,11 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -583,8 +584,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
-    let health_ext = Health::new(client.clone(), generation_health_receiver);
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -595,7 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health_sender,
+        generation_health,
     );

     // Duration buckets
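
The wiring in run follows from the two changes above: a single flag is created, one Arc handle is cloned into the health extension, and the original is moved into Infer. A compact sketch of that ownership pattern with simplified stand-in structs (not the real router types):

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Simplified stand-ins for the router's Health and Infer types.
struct Health {
    generation_health: Arc<AtomicBool>,
}

struct Infer {
    generation_health: Arc<AtomicBool>,
}

fn main() {
    let generation_health = Arc::new(AtomicBool::new(false));

    // The health extension keeps one handle...
    let health_ext = Health {
        generation_health: generation_health.clone(),
    };

    // ...and the inference side takes the other; both point at the same flag.
    let infer = Infer { generation_health };

    infer.generation_health.store(true, Ordering::SeqCst);
    assert!(health_ext.generation_health.load(Ordering::SeqCst));
}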