Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00

back to atomicbool

Commit: b2a5dd64c1
Parent: 3f1cc9bad7
@@ -1,7 +1,8 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
-use tokio::sync::watch;
 
 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -10,11 +11,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: watch::Receiver<bool>,
+    generation_health: Arc<AtomicBool>,
 }
 
 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
         Self {
             client,
             generation_health,
@@ -22,7 +23,7 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if *self.generation_health.borrow() {
+        if self.generation_health.load(Ordering::SeqCst) {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -58,7 +59,10 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            self.client.prefill(batch).await.is_ok()
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
         }
     }
 }
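The hunks above swap the Health type's generation-health signal from a tokio watch::Receiver<bool> to a shared Arc<AtomicBool>: check() first reads the flag and only falls back to the dummy prefill when generation has not yet been proven healthy, storing the probe's result back into the flag. Below is a minimal, self-contained sketch of that pattern; the names and the expensive_probe helper are illustrative stand-ins, not the crate's API.

// Sketch of the health-flag pattern (illustrative, std-only):
// a shared AtomicBool short-circuits an expensive probe.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

struct Health {
    generation_health: Arc<AtomicBool>,
}

impl Health {
    fn new(generation_health: Arc<AtomicBool>) -> Self {
        Self { generation_health }
    }

    fn check(&self) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // Generation already proved healthy; a cheap probe would suffice here.
            true
        } else {
            // Run the expensive probe (stand-in for the dummy prefill) and record the result.
            let value = expensive_probe();
            self.generation_health.store(value, Ordering::SeqCst);
            value
        }
    }
}

fn expensive_probe() -> bool {
    // Placeholder for the one-token dummy prefill used in the real check.
    true
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    let health = Health::new(flag.clone());
    assert!(health.check()); // first call takes the slow path and stores `true`
    assert!(flag.load(Ordering::SeqCst));
}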
@@ -4,13 +4,16 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -47,7 +50,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: watch::Sender<bool>,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -261,7 +264,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: watch::Sender<bool>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -368,7 +371,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -377,7 +380,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -391,7 +394,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -405,33 +408,29 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
-    tracing::info!("Decode");
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
-            tracing::info!("filter send");
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
-            tracing::info!("filter batch");
             let next_batch = filter_batch(client, next_batch, entries).await;
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            tracing::info!("Decode ended");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
@@ -466,13 +465,11 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("filter batch call");
         client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
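These hunks replace every watch::Sender::send(...) health update on the inference side with an atomic store: prefill and decode mark the flag true on success and false on failure, and the leftover tracing::info! debug lines are dropped. The sketch below shows that update pattern in isolation; run_batch and step are assumed placeholder names (the real code calls the sharded client's prefill/decode), and it presumes a tokio runtime is available.

// Hedged sketch: a background task records each batch outcome into a shared AtomicBool.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Stand-in for a prefill/decode round trip that may fail.
async fn run_batch(fail: bool) -> Result<(), ()> {
    if fail { Err(()) } else { Ok(()) }
}

async fn step(generation_health: &Arc<AtomicBool>, fail: bool) {
    match run_batch(fail).await {
        Ok(()) => {
            // Update health: the last generation round trip succeeded.
            generation_health.store(true, Ordering::SeqCst);
        }
        Err(()) => {
            // Update health: mark generation unhealthy so the health check re-probes.
            generation_health.store(false, Ordering::SeqCst);
        }
    }
}

#[tokio::main]
async fn main() {
    let generation_health = Arc::new(AtomicBool::new(false));
    step(&generation_health, false).await;
    assert!(generation_health.load(Ordering::SeqCst));
    step(&generation_health, true).await;
    assert!(!generation_health.load(Ordering::SeqCst));
}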
@@ -19,10 +19,11 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -583,8 +584,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
-    let health_ext = Health::new(client.clone(), generation_health_receiver);
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -595,7 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health_sender,
+        generation_health,
     );
 
     // Duration buckets
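In the server setup, the watch::channel(false) sender/receiver pair is replaced by a single Arc::new(AtomicBool::new(false)) whose clones go to the Health extension and to Infer's background batching task. The sketch below only illustrates that wiring idea under the same tokio-runtime assumption; it is not the crate's run() function.

// Illustrative wiring: one shared flag, clones handed to the health side and a background task.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

#[tokio::main]
async fn main() {
    let generation_health = Arc::new(AtomicBool::new(false));

    // Clone kept by the health-check side (cheap pointer copy).
    let health_flag = generation_health.clone();

    // Clone moved into the background "batching" task, which flips the flag.
    let task_flag = generation_health.clone();
    let handle = tokio::spawn(async move {
        task_flag.store(true, Ordering::SeqCst);
    });
    handle.await.unwrap();

    // The health side observes the update without any channel plumbing.
    assert!(health_flag.load(Ordering::SeqCst));
}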