From 98cfc9e70c92a777f12d3f0b8f6ab20f7bc065e7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Sun, 9 Apr 2023 19:41:24 +0200 Subject: [PATCH] fix(router): use buckets for metrics histograms --- router/src/infer.rs | 6 ++- router/src/server.rs | 97 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 5eafc3e9..2df9c5be 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -331,11 +331,12 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; + metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); match client.prefill(batch).await { Ok((generations, next_batch)) => { send_generations(generations, entries); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill"); + metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); next_batch } @@ -356,11 +357,12 @@ async fn decode( entries: &mut IntMap, ) -> Option { let start_time = Instant::now(); + metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); match client.decode(batches).await { Ok((generations, next_batch)) => { send_generations(generations, entries); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode"); + metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); next_batch } diff --git a/router/src/server.rs b/router/src/server.rs index f7850053..e1b402a6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -14,7 +14,7 @@ use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use futures::Stream; -use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use std::convert::Infallible; use std::net::SocketAddr; use text_generation_client::ShardedClient; @@ -120,6 +120,7 @@ async fn generate( ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); let start_time = Instant::now(); + metrics::increment_counter!("tgi_request_count"); let compute_characters = req.0.inputs.chars().count(); let mut add_prompt = None; @@ -185,6 +186,14 @@ async fn generate( let inference_time = Instant::now() - response.start; let time_per_token = inference_time / response.generated_text.generated_tokens; + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("validation_time", format!("{validation_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + span.record("time_per_token", format!("{time_per_token:?}")); + span.record("seed", format!("{:?}", response.generated_text.seed)); + // Headers let mut headers = HeaderMap::new(); headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); @@ -217,22 +226,22 @@ async fn generate( time_per_token.as_millis().to_string().parse().unwrap(), ); - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("validation_time", format!("{validation_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - span.record("time_per_token", format!("{time_per_token:?}")); - span.record("seed", format!("{:?}", response.generated_text.seed)); - tracing::info!("Output: {}", response.generated_text.text); - // Metrics metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time); - metrics::histogram!("tgi_request_validation_duration", validation_time); - metrics::histogram!("tgi_request_queue_duration", queue_time); - metrics::histogram!("tgi_request_inference_duration", inference_time); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token); + metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); + metrics::histogram!( + "tgi_request_validation_duration", + validation_time.as_secs_f64() + ); + metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!( + "tgi_request_inference_duration", + inference_time.as_secs_f64() + ); + metrics::histogram!( + "tgi_request_mean_time_per_token_duration", + time_per_token.as_secs_f64() + ); metrics::histogram!( "tgi_request_generated_tokens", response.generated_text.generated_tokens as f64 @@ -244,6 +253,8 @@ async fn generate( output_text = prompt + &output_text; } + tracing::info!("Output: {}", output_text); + let response = GenerateResponse { generated_text: output_text, details, @@ -294,6 +305,7 @@ async fn generate_stream( ) { let span = tracing::Span::current(); let start_time = Instant::now(); + metrics::increment_counter!("tgi_request_count"); let compute_characters = req.0.inputs.chars().count(); @@ -368,15 +380,14 @@ async fn generate_stream( span.record("inference_time", format!("{inference_time:?}")); span.record("time_per_token", format!("{time_per_token:?}")); span.record("seed", format!("{:?}", generated_text.seed)); - tracing::info!(parent: &span, "Output: {}", generated_text.text); // Metrics metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time); - metrics::histogram!("tgi_request_validation_duration", validation_time); - metrics::histogram!("tgi_request_queue_duration", queue_time); - metrics::histogram!("tgi_request_inference_duration", inference_time); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token); + metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); // StreamResponse @@ -387,6 +398,8 @@ async fn generate_stream( output_text = prompt + &output_text; } + tracing::info!(parent: &span, "Output: {}", output_text); + let stream_token = StreamResponse { token, generated_text: Some(output_text), @@ -513,8 +526,48 @@ pub async fn run( max_concurrent_requests, ); + // Duration buckets + let duration_matcher = Matcher::Suffix(String::from("duration")); + let n_duration_buckets = 35; + let mut duration_buckets = Vec::with_capacity(n_duration_buckets); + // Minimum duration in seconds + let mut value = 0.0001; + for _ in 0..n_duration_buckets { + // geometric sequence + value *= 1.5; + duration_buckets.push(value); + } + // Input Length buckets + let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length")); + let input_length_buckets: Vec = (0..100) + .map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64) + .collect(); + // Generated tokens buckets + let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens")); + let generated_tokens_buckets: Vec = (0..100) + .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64) + .collect(); + // Input Length buckets + let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens")); + let max_new_tokens_buckets: Vec = (0..100) + .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64) + .collect(); + // Batch size buckets + let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); + let batch_size_buckets: Vec = (0..max_batch_size).map(|x| (x + 1) as f64).collect(); + // Prometheus handler - let builder = PrometheusBuilder::new(); + let builder = PrometheusBuilder::new() + .set_buckets_for_metric(duration_matcher, &duration_buckets) + .unwrap() + .set_buckets_for_metric(input_length_matcher, &input_length_buckets) + .unwrap() + .set_buckets_for_metric(generated_tokens_matcher, &generated_tokens_buckets) + .unwrap() + .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) + .unwrap() + .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) + .unwrap(); let prom_handle = builder .install_recorder() .expect("failed to install metrics recorder");