Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
remove atomicbool
This commit is contained in:
parent 2c3772528d · commit f40f02fc25
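The diff below swaps the `Arc<AtomicBool>` that tracked generation health for a `tokio::sync::watch` channel: the batching task owns the `watch::Sender<bool>` and publishes the outcome of every prefill/decode call, while the `Health` extension used by the health route reads the latest value through a `watch::Receiver<bool>`. A minimal sketch of that pattern (illustrative names only, not the actual TGI code):

```rust
// Minimal sketch of the watch-channel pattern this commit adopts
// (hypothetical names; see the actual diff below).
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // One writer (the batching task), any number of readers (health checks).
    let (health_tx, health_rx) = watch::channel(false);

    // Writer side: publish the latest generation health after each batch.
    // `send` fails only if every receiver was dropped, hence `let _ =`.
    let _ = health_tx.send(true);

    // Reader side: `borrow` returns the most recently sent value without
    // blocking, playing the role `AtomicBool::load` played before.
    assert!(*health_rx.borrow());
}
```

A `watch` channel keeps only the most recent value, which matches the semantics the `AtomicBool` provided while also allowing change notification if it is ever needed.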
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,8 +1,7 @@
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
+use tokio::sync::watch;
 
 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -11,11 +10,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: Arc<AtomicBool>,
+    generation_health: watch::Receiver<bool>,
 }
 
 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
         Self {
             client,
             generation_health,
@@ -23,7 +22,7 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if self.generation_health.load(Ordering::SeqCst) {
+        if *self.generation_health.borrow() {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -59,10 +58,7 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            let value = self.client.prefill(batch).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            self.client.prefill(batch).await.is_ok()
         }
     }
 }
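Note that `Health::check` no longer stores the probe result back: a `watch::Receiver` is read-only, so the batching task remains the single writer. A sketch of the resulting read-side control flow, with the gRPC probe stubbed out (`probe` is a hypothetical stand-in, not the real client call):

```rust
use tokio::sync::watch;

// Read-only health check mirroring the new control flow: the receiver
// observes health but cannot write it back.
async fn check(generation_health: &watch::Receiver<bool>) -> bool {
    if *generation_health.borrow() {
        // Fast path: the last real batch succeeded.
        true
    } else {
        // Slow path: the real code sends a dummy prefill batch here;
        // `probe` stands in for `client.prefill(batch).await.is_ok()`.
        probe().await
    }
}

async fn probe() -> bool {
    true
}

#[tokio::main]
async fn main() {
    let (_tx, rx) = watch::channel(true);
    assert!(check(&rx).await);
}
```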
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -4,16 +4,13 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::{
-    atomic::{AtomicBool, Ordering},
-    Arc,
-};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -50,7 +47,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: Arc<AtomicBool>,
+        generation_health: watch::Sender<bool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -264,7 +261,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: Arc<AtomicBool>,
+    generation_health: watch::Sender<bool>,
 ) {
     // Infinite loop
     loop {
@@ -371,7 +368,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
+    generation_health: &watch::Sender<bool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -380,7 +377,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            generation_health.store(true, Ordering::SeqCst);
+            let _ = generation_health.send(true);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -394,7 +391,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            generation_health.store(false, Ordering::SeqCst);
+            let _ = generation_health.send(false);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -408,7 +405,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &Arc<AtomicBool>,
+    generation_health: &watch::Sender<bool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -417,7 +414,7 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            generation_health.store(true, Ordering::SeqCst);
+            let _ = generation_health.send(true);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -430,7 +427,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            generation_health.store(false, Ordering::SeqCst);
+            let _ = generation_health.send(false);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
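Throughout the `prefill` and `decode` changes above, the result of `send` is deliberately discarded with `let _ =`: `watch::Sender::send` fails only when every receiver has been dropped, a condition the batching task has no reason to treat as fatal. A small demonstration of that behavior:

```rust
use tokio::sync::watch;

fn main() {
    let (tx, rx) = watch::channel(false);
    // While a receiver is alive, sending succeeds.
    assert!(tx.send(true).is_ok());
    // Once the last receiver is dropped, `send` returns a SendError;
    // the batching task can safely ignore it.
    drop(rx);
    assert!(tx.send(false).is_err());
}
```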
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -19,11 +19,10 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use std::sync::atomic::AtomicBool;
-use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
+use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -584,8 +583,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
+    let health_ext = Health::new(client.clone(), generation_health_receiver);
     let infer = Infer::new(
         client,
         validation,
@@ -596,7 +595,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health,
+        generation_health_sender,
     );
 
     // Duration buckets