mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
back to atomicbool
This commit is contained in:
parent 3f1cc9bad7
commit b2a5dd64c1
@@ -1,7 +1,8 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
-use tokio::sync::watch;

 // Note: Request ids and batch ids cannot collide.
 const LIVENESS_ID: u64 = u64::MAX;
@@ -10,11 +11,11 @@ const BATCH_ID: u64 = u64::MAX;
 #[derive(Clone, Debug)]
 pub(crate) struct Health {
     client: ShardedClient,
-    generation_health: watch::Receiver<bool>,
+    generation_health: Arc<AtomicBool>,
 }

 impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: watch::Receiver<bool>) -> Self {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
         Self {
             client,
             generation_health,
@@ -22,7 +23,7 @@ impl Health {
     }

     pub(crate) async fn check(&mut self) -> bool {
-        if *self.generation_health.borrow() {
+        if self.generation_health.load(Ordering::SeqCst) {
             // Generation is healthy, we only check that the shards are answering gRPC calls
             self.client.health().await.is_ok()
         } else {
@@ -58,7 +59,10 @@ impl Health {
                 max_tokens: 2,
             };
             // Skips the queue
-            self.client.prefill(batch).await.is_ok()
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
         }
     }
 }
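The health checker now reads and updates a plain atomic flag instead of holding a watch::Receiver. Below is a minimal, self-contained sketch of the resulting pattern; MockClient, prefill_probe, and main are illustrative stand-ins for the real ShardedClient and its gRPC prefill call, not code from this commit.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Illustrative stand-in for ShardedClient: a probe that may fail.
struct MockClient {
    healthy: bool,
}

impl MockClient {
    fn prefill_probe(&self) -> Result<(), ()> {
        if self.healthy { Ok(()) } else { Err(()) }
    }
}

struct Health {
    client: MockClient,
    generation_health: Arc<AtomicBool>,
}

impl Health {
    fn check(&mut self) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // Generation already marked healthy: a cheap liveness check is enough.
            true
        } else {
            // Run a tiny probe and record the result in the shared flag.
            let value = self.client.prefill_probe().is_ok();
            self.generation_health.store(value, Ordering::SeqCst);
            value
        }
    }
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    let mut health = Health {
        client: MockClient { healthy: true },
        generation_health: flag.clone(),
    };
    assert!(health.check()); // probe succeeds, flag flips to true
    assert!(flag.load(Ordering::SeqCst));
}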
@@ -4,13 +4,16 @@ use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -47,7 +50,7 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
-        generation_health: watch::Sender<bool>,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size);
@@ -261,7 +264,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
-    generation_health: watch::Sender<bool>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -368,7 +371,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -377,7 +380,7 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

@@ -391,7 +394,7 @@ async fn prefill(
         // If we have an error, we discard the whole batch
         Err(err) => {
             // Update health
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
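On the batching side, prefill now publishes its outcome by storing into the shared flag rather than sending on a watch channel; storing true on Ok and false on Err is equivalent to storing result.is_ok(). A small, self-contained sketch of that producer side follows, where record_health is a hypothetical helper standing in for the inline stores in prefill and decode.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Record the outcome of one inference step in the shared health flag.
// `result` stands in for the real client.prefill(..) / client.decode(..) outcome.
fn record_health<T, E>(result: &Result<T, E>, generation_health: &Arc<AtomicBool>) {
    generation_health.store(result.is_ok(), Ordering::SeqCst);
}

fn main() {
    let generation_health = Arc::new(AtomicBool::new(false));

    let ok: Result<(), &str> = Ok(());
    record_health(&ok, &generation_health);
    assert!(generation_health.load(Ordering::SeqCst));

    let err: Result<(), &str> = Err("shard error");
    record_health(&err, &generation_health);
    assert!(!generation_health.load(Ordering::SeqCst));
}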
@@ -405,33 +408,29 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
-    generation_health: &watch::Sender<bool>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

-    tracing::info!("Decode");
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             // Update health
-            let _ = generation_health.send(true);
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
-            tracing::info!("filter send");
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            tracing::info!("filter batch");
             let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            tracing::info!("Decode ended");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            let _ = generation_health.send(false);
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
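decode follows the same pattern and also drops the leftover debug tracing::info! calls. One likely reason for preferring the atomic (the commit itself gives no rationale): watch::Sender::send returns an error once every receiver has been dropped, which forced the `let _ =` in the old code, while AtomicBool::store cannot fail and the Arc clones cheaply into any task. A short sketch of that contrast, assuming tokio's sync feature is available:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::watch;

fn main() {
    // watch channel: send() errors once every receiver has been dropped,
    // which is why the old code ignored the result with `let _ =`.
    let (tx, rx) = watch::channel(false);
    drop(rx);
    assert!(tx.send(true).is_err());

    // AtomicBool behind an Arc: store() cannot fail, and the flag clones
    // cheaply into the batching task and the health checker alike.
    let flag = Arc::new(AtomicBool::new(false));
    let batching_side = flag.clone();
    batching_side.store(true, Ordering::SeqCst);
    assert!(flag.load(Ordering::SeqCst));
}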
@@ -466,13 +465,11 @@ async fn filter_batch(
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("clear cache");
         client.clear_cache(Some(id)).await.unwrap();
         None
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        tracing::info!("filter batch call");
         client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
@@ -19,10 +19,11 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::watch;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -583,8 +584,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let (generation_health_sender, generation_health_receiver) = watch::channel(false);
-    let health_ext = Health::new(client.clone(), generation_health_receiver);
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -595,7 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         shard_info.window_size,
-        generation_health_sender,
+        generation_health,
     );

     // Duration buckets
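In run(), the watch channel is replaced by a single Arc<AtomicBool> that is cloned into both the health checker and Infer (and from there into the batching task). A self-contained sketch of that wiring follows, where HealthChecker and Batcher are hypothetical stand-ins for the real Health and Infer types.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical stand-ins for the router's Health and Infer types.
struct HealthChecker {
    generation_health: Arc<AtomicBool>,
}

struct Batcher {
    generation_health: Arc<AtomicBool>,
}

impl Batcher {
    // The batching side reports the outcome of each step.
    fn report(&self, ok: bool) {
        self.generation_health.store(ok, Ordering::SeqCst);
    }
}

impl HealthChecker {
    // The health endpoint only has to read the flag.
    fn is_generation_healthy(&self) -> bool {
        self.generation_health.load(Ordering::SeqCst)
    }
}

fn main() {
    // One flag, cloned into both components, mirroring the wiring in run().
    let generation_health = Arc::new(AtomicBool::new(false));
    let health = HealthChecker { generation_health: generation_health.clone() };
    let batcher = Batcher { generation_health };

    assert!(!health.is_generation_healthy());
    batcher.report(true);
    assert!(health.is_generation_healthy());
}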