add store true when successful prefill/decode

commit e7503a4240 (parent 3b2d1a2854)
Author: OlivierDehaene
Date: 2023-04-26 19:14:21 +02:00

4 changed files with 104 additions and 89 deletions

router/src/health.rs (new file, 62 lines)

@@ -0,0 +1,62 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
#[derive(Clone, Debug)]
pub(crate) struct Health {
client: ShardedClient,
generation_health: Arc<AtomicBool>,
}
impl Health {
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards are answering gRPC calls
self.client.health().await.is_ok()
} else {
// Generation is unhealthy, or no generation request has been sent yet
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
truncate: 10,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
// Skips the queue
let value = self.client.prefill(batch).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}
}
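
The new checker boils down to a small pattern: a shared `AtomicBool` caches the last known generation status, and the dummy one-token prefill is only issued when that flag is false. Below is a minimal, self-contained sketch of the same pattern; `FakeClient` is a hypothetical stand-in for the real `ShardedClient` and is not part of this commit:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Illustrative stand-in for the real sharded gRPC client (not part of this commit).
struct FakeClient;

impl FakeClient {
    /// Pretend to run a one-token dummy prefill.
    fn prefill(&mut self) -> Result<(), ()> {
        Ok(())
    }
    /// Pretend to answer a plain gRPC health ping.
    fn health(&mut self) -> Result<(), ()> {
        Ok(())
    }
}

struct Health {
    client: FakeClient,
    generation_health: Arc<AtomicBool>,
}

impl Health {
    fn check(&mut self) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // A generation request succeeded recently: a cheap ping is enough.
            self.client.health().is_ok()
        } else {
            // No successful generation yet: run a dummy prefill and cache the result.
            let ok = self.client.prefill().is_ok();
            self.generation_health.store(ok, Ordering::SeqCst);
            ok
        }
    }
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    let mut health = Health {
        client: FakeClient,
        generation_health: flag.clone(),
    };
    assert!(health.check()); // first call exercises the dummy prefill
    assert!(flag.load(Ordering::SeqCst)); // and records the result in the shared flag
}
```

The `SeqCst` ordering mirrors what the commit uses; for an independent boolean flag a relaxed ordering would also be sufficient, but the sketch stays faithful to the diff.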

router/src/infer.rs

@@ -30,8 +30,6 @@ pub struct Infer
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Has done roundtrip valid run
healthy: Arc<AtomicBool>,
}
/// Infer shared state
@@ -41,6 +39,7 @@ struct Shared {
}
impl Infer {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
validation: Validation,
@@ -49,6 +48,7 @@ impl Infer {
max_waiting_tokens: usize,
max_concurrent_requests: usize,
requires_padding: bool,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding);
@@ -64,29 +64,20 @@ impl Infer {
max_waiting_tokens,
queue.clone(),
shared.clone(),
generation_health,
));
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
let healthy = Arc::new(AtomicBool::new(false));
Self {
validation,
queue,
shared,
limit_concurrent_requests: semaphore,
healthy,
}
}
pub(crate) fn healthy(&self) -> bool {
self.healthy.load(Ordering::SeqCst)
}
pub(crate) fn set_healthy(&self, value: bool) {
self.healthy.store(value, Ordering::SeqCst)
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream(
@@ -255,6 +246,7 @@ async fn batching_task(
max_waiting_tokens: usize,
queue: Queue,
shared: Arc<Shared>,
generation_health: Arc<AtomicBool>,
) {
// Infinite loop
loop {
@@ -267,7 +259,7 @@ async fn batching_task(
while let Some((mut entries, batch, span)) =
queue.next_batch(None, max_batch_total_tokens).await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@@ -316,7 +308,8 @@ async fn batching_task(
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
let new_cached_batch =
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
.instrument(span)
.await;
// Reset waiting counter
@@ -342,7 +335,7 @@ async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
@@ -358,6 +351,7 @@ async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<Batch> {
let start_time = Instant::now();
let batch_id = batch.id;
@@ -365,6 +359,8 @@ async fn prefill(
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
@@ -377,6 +373,8 @@ async fn prefill(
}
// If we have an error, we discard the whole batch
Err(err) => {
// Update health
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -390,6 +388,7 @@ async fn decode(
client: &mut ShardedClient,
batches: Vec<Batch>,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<Batch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -397,6 +396,8 @@ async fn decode(
match client.decode(batches).await {
Ok((generations, next_batch)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
@@ -409,6 +410,7 @@ async fn decode(
}
// If we have an error, we discard the whole batch
Err(err) => {
generation_health.store(false, Ordering::SeqCst);
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
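
On the batching side the flag is written rather than read: every prefill/decode stores `true` on success and `false` on failure, so `/health` can usually answer from the cached value without generating anything. A reduced producer/consumer sketch of that flow, assuming tokio with the `macros` and `rt-multi-thread` features; `fake_generation_step` is an illustrative stand-in for the gRPC calls, not part of this commit:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Stand-in for `client.prefill(batch)` / `client.decode(batches)` (not the real gRPC call).
async fn fake_generation_step(succeed: bool) -> Result<(), ()> {
    if succeed { Ok(()) } else { Err(()) }
}

#[tokio::main]
async fn main() {
    let generation_health = Arc::new(AtomicBool::new(false));

    // Producer side: the batching loop stores the outcome of every step.
    let producer_flag = generation_health.clone();
    let batching = tokio::spawn(async move {
        for succeed in [true, true, false] {
            let ok = fake_generation_step(succeed).await.is_ok();
            producer_flag.store(ok, Ordering::SeqCst);
        }
    });
    batching.await.unwrap();

    // Consumer side: the /health handler only reads the cached value.
    assert!(!generation_health.load(Ordering::SeqCst)); // the last step failed
}
```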

router/src/lib.rs

@@ -1,3 +1,4 @@
mod health;
/// Text Generation Inference Webserver
mod infer;
mod queue;
@@ -7,7 +8,6 @@ mod validation;
use infer::Infer;
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use text_generation_client::ShardedClient;
use utoipa::ToSchema;
use validation::Validation;
@@ -20,11 +20,6 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>,
}
#[derive(Clone, Debug)]
pub struct Health {
pub client: ShardedClient,
}
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
/// Model info

router/src/server.rs

@@ -1,10 +1,11 @@
use crate::health::Health;
/// HTTP Server logic
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, Health, HubModelInfo, Infer, Info,
PrefillToken, StreamDetails, StreamResponse, Token, Validation,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -18,6 +19,8 @@ use futures::Stream;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer;
use tokio::signal;
@@ -86,54 +89,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
get,
tag = "Text Generation Inference",
path = "/health",
request_body = HealthRequest,
responses(
(status = 200, description = "Everything is working fine"),
(status = 500, description = "Text generation inference is down", body = ErrorResponse,
(status = 503, description = "Text generation inference is down", body = ErrorResponse,
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)
)]
#[instrument(skip(infer))]
#[instrument(skip(health))]
/// Health check method
async fn health(
mut health: Extension<Health>,
infer: Extension<Infer>,
) -> Result<Json<()>, (StatusCode, Json<ErrorResponse>)> {
if infer.healthy() {
health.client.health().await.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match health.check().await {
true => Ok(()),
false => Err((
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: "unhealthy".to_string(),
error_type: "healthcheck".to_string(),
}),
)
})?;
} else {
infer
.generate(GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: false,
max_new_tokens: 1,
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,
},
})
.await?;
infer.set_healthy(true);
)),
}
Ok(axum::Json(()))
}
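
For context, a handler shaped like the new one is registered in axum behind an `Extension` layer that carries the shared checker. A self-contained sketch assuming axum 0.6-era APIs, with a bare `Arc<AtomicBool>` standing in for the `Health` struct and an arbitrary bind address:

```rust
use axum::{extract::Extension, http::StatusCode, routing::get, Router};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Minimal /health handler: 200 when the shared flag is set, 503 otherwise.
async fn health(Extension(flag): Extension<Arc<AtomicBool>>) -> Result<(), StatusCode> {
    if flag.load(Ordering::SeqCst) {
        Ok(())
    } else {
        Err(StatusCode::SERVICE_UNAVAILABLE)
    }
}

#[tokio::main]
async fn main() {
    let flag = Arc::new(AtomicBool::new(true));
    let app = Router::new()
        .route("/health", get(health))
        .layer(Extension(flag));

    // Address chosen arbitrarily for the sketch.
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}
```
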
/// Generate tokens
@@ -184,27 +158,10 @@ async fn generate(
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = match infer.generate_best_of(req.0, best_of).await {
Ok(result) => result,
Err(err) => {
infer.set_healthy(false);
return Err(err)?;
}
};
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
(response, Some(best_of_responses))
}
_ => (
{
match infer.generate(req.0).await {
Ok(result) => result,
Err(err) => {
infer.set_healthy(false);
return Err(err)?;
}
}
},
None,
),
_ => (infer.generate(req.0).await?, None),
};
// Token details
@@ -483,7 +440,6 @@ async fn generate_stream(
// yield error
Err(err) => {
error = true;
infer.set_healthy(false);
yield Ok(Event::from(err));
break;
}
@@ -595,9 +551,8 @@ pub async fn run(
max_input_length,
max_total_tokens,
);
let health_ext = Health {
client: client.clone(),
};
let healthy = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), healthy.clone());
let infer = Infer::new(
client,
validation,
@@ -606,6 +561,7 @@ pub async fn run(
max_waiting_tokens,
max_concurrent_requests,
shard_info.requires_padding,
healthy,
);
// Duration buckets