From e7503a42401da0c32e1fe2679cfddada40581454 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 26 Apr 2023 19:14:21 +0200
Subject: [PATCH] add store true when successful prefill/decode

---
 router/src/health.rs | 62 +++++++++++++++++++++++++++++++
 router/src/infer.rs  | 36 +++++++++---------
 router/src/lib.rs    |  7 +---
 router/src/server.rs | 88 +++++++++++---------------------------------
 4 files changed, 104 insertions(+), 89 deletions(-)
 create mode 100644 router/src/health.rs

diff --git a/router/src/health.rs b/router/src/health.rs
new file mode 100644
index 00000000..02edf328
--- /dev/null
+++ b/router/src/health.rs
@@ -0,0 +1,62 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use text_generation_client::{
+    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct Health {
+    client: ShardedClient,
+    generation_health: Arc<AtomicBool>,
+}
+
+impl Health {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+        Self {
+            client,
+            generation_health,
+        }
+    }
+
+    pub(crate) async fn check(&mut self) -> bool {
+        if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards are answering gRPC calls
+            self.client.health().await.is_ok()
+        } else {
+            // Generation is unhealthy or no generation request has been sent yet
+
+            // Dummy batch of 1 token and 1 generated token
+            let liveness_request = Request {
+                id: u64::MAX,
+                inputs: "liveness".to_string(),
+                truncate: 10,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+            };
+            let batch = Batch {
+                id: u64::MAX,
+                requests: vec![liveness_request],
+                size: 1,
+                max_tokens: 2,
+            };
+            // Skips the queue
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
+        }
+    }
+}
diff --git a/router/src/infer.rs b/router/src/infer.rs
index ce5c5dce..313ec3e1 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -30,8 +30,6 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
-    /// Has done roundtrip valid run
-    healthy: Arc<AtomicBool>,
 }
 
 /// Infer shared state
@@ -41,6 +39,7 @@ struct Shared {
 }
 
 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -49,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -64,29 +64,20 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
-        let healthy = Arc::new(AtomicBool::new(false));
 
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            healthy,
         }
     }
 
-    pub(crate) fn healthy(&self) -> bool {
-        self.healthy.load(Ordering::SeqCst)
-    }
-
-    pub(crate) fn set_healthy(&self, value: bool) {
-        self.healthy.store(value, Ordering::SeqCst)
-    }
-
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip(self))]
     pub(crate) async fn generate_stream(
@@ -255,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -267,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -316,9 +308,10 @@ async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
@@ -342,7 +335,7 @@ async fn batching_task(
                 entry.temp_span = Some(entry_batch_span);
             });
 
-            cached_batch = decode(&mut client, batches, &mut entries)
+            cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                 .instrument(next_batch_span)
                 .await;
             waiting_tokens += 1;
@@ -358,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -365,6 +359,8 @@ async fn prefill(
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -377,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
            let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -390,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -397,6 +396,8 @@ async fn decode(
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -409,6 +410,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index bf2112a9..b2dc900a 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -7,7 +8,6 @@ mod validation;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
-use text_generation_client::ShardedClient;
 use utoipa::ToSchema;
 use validation::Validation;
 
@@ -20,11 +20,6 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Debug)]
-pub struct Health {
-    pub client: ShardedClient,
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
diff --git a/router/src/server.rs b/router/src/server.rs
index a0bdf42e..1fd48963 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,10 +1,11 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
-    PrefillToken, StreamDetails, StreamResponse, Token, Validation,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
+    StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -86,54 +89,25 @@ async fn get_model_info(info: Extension<HubModelInfo>) -> Json<HubModelInfo> {
     get,
     tag = "Text Generation Inference",
     path = "/health",
-    request_body = HealthRequest,
     responses(
         (status = 200, description = "Everything is working fine"),
-        (status = 500, description = "Text generation inference is down", body = ErrorResponse,
+        (status = 503, description = "Text generation inference is down", body = ErrorResponse,
             example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
     )
 )]
-#[instrument(skip(infer))]
+#[instrument(skip(health))]
 /// Health check method
-async fn health(
-    mut health: Extension<Health>,
-    infer: Extension<Infer>,
-) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
-    if infer.healthy() {
-        health.client.health().await.map_err(|_| {
-            (
-                StatusCode::INTERNAL_SERVER_ERROR,
-                Json(ErrorResponse {
-                    error: "unhealthy".to_string(),
-                    error_type: "healthcheck".to_string(),
-                }),
-            )
-        })?;
-    } else {
-        infer
-            .generate(GenerateRequest {
-                inputs: "liveness".to_string(),
-                parameters: GenerateParameters {
-                    best_of: None,
-                    temperature: None,
-                    repetition_penalty: None,
-                    top_k: None,
-                    top_p: None,
-                    typical_p: None,
-                    do_sample: false,
-                    max_new_tokens: 1,
-                    return_full_text: None,
-                    stop: Vec::new(),
-                    truncate: None,
-                    watermark: false,
-                    details: false,
-                    seed: None,
-                },
-            })
-            .await?;
-        infer.set_healthy(true);
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
     }
-    Ok(axum::Json(()))
 }
 
 /// Generate tokens
@@ -184,27 +158,10 @@ async fn generate(
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
-                Ok(result) => result,
-                Err(err) => {
-                    infer.set_healthy(false);
-                    return Err(err)?;
-                }
-            };
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (
-            {
-                match infer.generate(req.0).await {
-                    Ok(result) => result,
-                    Err(err) => {
-                        infer.set_healthy(false);
-                        return Err(err)?;
-                    }
-                }
-            },
-            None,
-        ),
+        _ => (infer.generate(req.0).await?, None),
     };
 
     // Token details
@@ -483,7 +440,6 @@ async fn generate_stream(
                         // yield error
                         Err(err) => {
                             error = true;
-                            infer.set_healthy(false);
                             yield Ok(Event::from(err));
                             break;
                         }
@@ -595,9 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
-    let health_ext = Health {
-        client: client.clone(),
-    };
+    let healthy = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), healthy.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -606,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        healthy,
     );
 
     // Duration buckets
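
A minimal probe sketch for the new /health route (illustrative only, not part of this patch): it assumes the router listens on http://localhost:3000 and that the reqwest and tokio crates are available. A 200 status means Health::check passed (the shards answered, or the dummy prefill succeeded), while 503 carries the "unhealthy"/"healthcheck" error body wired up above.

use std::time::Duration;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Short timeout so a wedged router is reported as unhealthy instead of hanging the probe
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(2))
        .build()?;

    // GET /health: 200 when the health check passes,
    // 503 with {"error": "unhealthy", "error_type": "healthcheck"} otherwise
    let status = client
        .get("http://localhost:3000/health") // hypothetical router address
        .send()
        .await?
        .status();

    if status.is_success() {
        println!("router healthy");
    } else {
        println!("router unhealthy: HTTP {status}");
    }
    Ok(())
}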