back to atomicbool

OlivierDehaene 2023-10-23 12:47:16 +02:00
parent 3f1cc9bad7
commit b2a5dd64c1
3 changed files with 27 additions and 25 deletions

router/src/health.rs

@@ -1,7 +1,8 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
-use tokio::sync::watch;
 
 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -10,11 +11,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: watch::Receiver<bool>,
+    generation_health: Arc<AtomicBool>,
 }
 
 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
         Self {
             client,
             generation_health,
@@ -22,7 +23,7 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if *self.generation_health.borrow() {
+        if self.generation_health.load(Ordering::SeqCst) {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -58,7 +59,10 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            self.client.prefill(batch).await.is_ok()
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
         }
     }
 }

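The hunks above replace the tokio::sync::watch channel with a shared Arc<AtomicBool>: the generation path stores the outcome of the last prefill call, and Health::check only loads it. Below is a minimal, self-contained sketch of that pattern (not part of the commit; it uses a plain std thread in place of the batching task, and the names are illustrative).

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    // Shared health flag; each side holds its own clone of the Arc.
    let generation_health = Arc::new(AtomicBool::new(false));

    // Writer side (stands in for the batching task): record whether the
    // last generation call succeeded.
    let writer = generation_health.clone();
    let batching = thread::spawn(move || {
        let last_call_ok = true; // stand-in for client.prefill(batch).await.is_ok()
        writer.store(last_call_ok, Ordering::SeqCst);
    });
    batching.join().unwrap();

    // Reader side (stands in for Health::check): just load the current value.
    assert!(generation_health.load(Ordering::SeqCst));
}

Unlike watch::Sender::send, which returns a Result and fails once every receiver has been dropped, AtomicBool::store is infallible, which is why the `let _ =` wrappers disappear in the infer.rs hunks below.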
router/src/infer.rs

@@ -4,13 +4,16 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -47,7 +50,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: watch::Sender<bool>,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -261,7 +264,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: watch::Sender<bool>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -368,7 +371,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -377,7 +380,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
@@ -391,7 +394,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -405,33 +408,29 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
-    tracing::info!("Decode");
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
-            tracing::info!("filter send");
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
-            tracing::info!("filter batch");
             let next_batch = filter_batch(client, next_batch, entries).await;
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            tracing::info!("Decode ended");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
@@ -466,13 +465,11 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("filter batch call");
         client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }

router/src/server.rs

@@ -19,10 +19,11 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -583,8 +584,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
-    let health_ext = Health::new(client.clone(), generation_health_receiver);
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -595,7 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health_sender,
+        generation_health,
     );
 
     // Duration buckets