mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
remove atomicbool
This commit is contained in:
parent
2c3772528d
commit
f40f02fc25
@ -1,8 +1,7 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||
};
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Note: Request ids and batch ids cannot collide.
|
||||
const LIVENESS_ID: u64 = u64::MAX;
|
||||
@ -11,11 +10,11 @@ const BATCH_ID: u64 = u64::MAX;
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct Health {
|
||||
client: ShardedClient,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
generation_health: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl Health {
|
||||
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
|
||||
pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
generation_health,
|
||||
@ -23,7 +22,7 @@ impl Health {
|
||||
}
|
||||
|
||||
pub(crate) async fn check(&mut self) -> bool {
|
||||
if self.generation_health.load(Ordering::SeqCst) {
|
||||
if *self.generation_health.borrow() {
|
||||
// Generation is healthy, we only check that the shards are answering gRPC calls
|
||||
self.client.health().await.is_ok()
|
||||
} else {
|
||||
@ -59,10 +58,7 @@ impl Health {
|
||||
max_tokens: 2,
|
||||
};
|
||||
// Skips the queue
|
||||
let value = self.client.prefill(batch).await.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
self.client.prefill(batch).await.is_ok()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,16 +4,13 @@ use crate::{Entry, Queue, Token};
|
||||
use crate::{GenerateRequest, PrefillToken};
|
||||
use futures::future::try_join_all;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
@ -50,7 +47,7 @@ impl Infer {
|
||||
max_concurrent_requests: usize,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
generation_health: watch::Sender<bool>,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(requires_padding, 16, window_size);
|
||||
@ -264,7 +261,7 @@ async fn batching_task(
|
||||
max_waiting_tokens: usize,
|
||||
queue: Queue,
|
||||
shared: Arc<Shared>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
generation_health: watch::Sender<bool>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
@ -371,7 +368,7 @@ async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
generation_health: &watch::Sender<bool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
@ -380,7 +377,7 @@ async fn prefill(
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
let _ = generation_health.send(true);
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
@ -394,7 +391,7 @@ async fn prefill(
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
// Update health
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = generation_health.send(false);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
@ -408,7 +405,7 @@ async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
generation_health: &watch::Sender<bool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
@ -417,7 +414,7 @@ async fn decode(
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
let _ = generation_health.send(true);
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
@ -430,7 +427,7 @@ async fn decode(
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = generation_health.send(false);
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
|
@ -19,11 +19,10 @@ use futures::Stream;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{ShardInfo, ShardedClient};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::sync::watch;
|
||||
use tokio::time::Instant;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
@ -584,8 +583,8 @@ pub async fn run(
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
let generation_health = Arc::new(AtomicBool::new(false));
|
||||
let health_ext = Health::new(client.clone(), generation_health.clone());
|
||||
let (generation_health_sender, generation_health_receiver) = watch::channel(false);
|
||||
let health_ext = Health::new(client.clone(), generation_health_receiver);
|
||||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
@ -596,7 +595,7 @@ pub async fn run(
|
||||
max_concurrent_requests,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
generation_health,
|
||||
generation_health_sender,
|
||||
);
|
||||
|
||||
// Duration buckets
|
||||
|
Loading…
Reference in New Issue
Block a user