diff --git a/Cargo.lock b/Cargo.lock
index e8a28bf9..944047e4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8,6 +8,17 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"

+[[package]]
+name = "ahash"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
+dependencies = [
+ "getrandom",
+ "once_cell",
+ "version_check",
+]
+
 [[package]]
 name = "aho-corasick"
 version = "0.7.20"
@@ -806,6 +817,9 @@ name = "hashbrown"
 version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
+dependencies = [
+ "ahash",
+]

 [[package]]
 name = "heck"
@@ -1093,6 +1107,15 @@ dependencies = [
  "cfg-if",
 ]

+[[package]]
+name = "mach"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "macro_rules_attribute"
 version = "0.1.3"
@@ -1139,6 +1162,64 @@ dependencies = [
  "autocfg",
 ]

+[[package]]
+name = "metrics"
+version = "0.20.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
+dependencies = [
+ "ahash",
+ "metrics-macros",
+ "portable-atomic",
+]
+
+[[package]]
+name = "metrics-exporter-prometheus"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
+dependencies = [
+ "hyper",
+ "indexmap",
+ "ipnet",
+ "metrics",
+ "metrics-util",
+ "parking_lot",
+ "portable-atomic",
+ "quanta",
+ "thiserror",
+ "tokio",
+ "tracing",
+]
+
+[[package]]
+name = "metrics-macros"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "metrics-util"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+ "hashbrown",
+ "metrics",
+ "num_cpus",
+ "parking_lot",
+ "portable-atomic",
+ "quanta",
+ "sketches-ddsketch",
+]
+
 [[package]]
 name = "mime"
 version = "0.3.16"
@@ -1514,6 +1595,12 @@ version = "0.3.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"

+[[package]]
+name = "portable-atomic"
+version = "0.3.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.17"
@@ -1618,6 +1705,22 @@ dependencies = [
  "prost",
 ]

+[[package]]
+name = "quanta"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
+dependencies = [
+ "crossbeam-utils",
+ "libc",
+ "mach",
+ "once_cell",
+ "raw-cpuid",
+ "wasi 0.10.0+wasi-snapshot-preview1",
+ "web-sys",
+ "winapi",
+]
+
 [[package]]
 name = "quote"
 version = "1.0.23"
@@ -1657,6 +1760,15 @@ dependencies = [
  "getrandom",
 ]

+[[package]]
+name = "raw-cpuid"
+version = "10.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c307f7aacdbab3f0adee67d52739a1d71112cc068d6fab169ddeb18e48877fad"
+dependencies = [
+ "bitflags",
+]
+
 [[package]]
 name = "rayon"
 version = "1.6.1"
@@ -1980,6 +2092,12 @@ dependencies = [
  "libc",
 ]

+[[package]]
+name = "sketches-ddsketch"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
+
 [[package]]
 name = "slab"
 version = "0.4.7"
@@ -2143,6 +2261,8 @@ dependencies = [
  "axum-tracing-opentelemetry",
  "clap 4.1.4",
  "futures",
+ "metrics",
+ "metrics-exporter-prometheus",
  "nohash-hasher",
  "opentelemetry",
  "opentelemetry-otlp",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index f1ace790..9ac500c9 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -19,6 +19,8 @@ axum-tracing-opentelemetry = "0.9.0"
 text-generation-client = { path = "client" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "0.3.26"
+metrics = "0.20.1"
+metrics-exporter-prometheus = { version = "0.11.0", features = [] }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 4e368492..dc0df50a 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
 use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
 use nohash_hasher::IntMap;
-use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
@@ -81,6 +80,7 @@ impl Infer {
             .limit_concurrent_requests
             .try_acquire_owned()
             .map_err(|err| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
                 tracing::error!("{err}");
                 err
             })?;
@@ -172,6 +172,7 @@ impl Infer {
             })
         } else {
             let err = InferError::IncompleteGeneration;
+            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
             tracing::error!("{err}");
             Err(err)
         }
@@ -201,7 +202,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-            let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -212,6 +213,7 @@ async fn batching_task(
                 // Get current batch info
                 let batch_size = batch.size;
                 let mut batches = vec![batch];
+                metrics::gauge!("tgi_batch_current_size", batch_size as f64);

                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
@@ -241,10 +243,9 @@ async fn batching_task(
                     });

                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch =
-                        wrap_future(client.prefill(new_batch), &mut new_entries)
-                            .instrument(span)
-                            .await;
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
@@ -268,29 +269,59 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });

-                cached_batch = wrap_future(client.decode(batches), &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
             }
+            metrics::gauge!("tgi_batch_current_size", 0.0);
         }
     }
 }

-/// Wrap a future inside a match statement to handle errors and send the responses to Infer
 #[instrument(skip_all)]
-async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+async fn prefill(
+    client: &mut ShardedClient,
+    batch: Batch,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
-    match future.await {
+    let start_time = Instant::now();
+
+    match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+            None
+        }
+    }
+}
+
+#[instrument(skip_all)]
+async fn decode(
+    client: &mut ShardedClient,
+    batches: Vec<Batch>,
+    entries: &mut IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let start_time = Instant::now();
+
+    match client.decode(batches).await {
+        Ok((generations, next_batch)) => {
+            send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
+            next_batch
+        }
+        // If we have an error, we discard the whole batch
+        Err(err) => {
+            send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None
         }
     }
@@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // Create and enter a span to link this function back to the entry
         let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
         tracing::error!("{err}");

         // unwrap_or is valid here as we don't care if the receiver is gone.
diff --git a/router/src/queue.rs b/router/src/queue.rs
index b155a1af..8962aaec 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -132,6 +132,7 @@ impl State {
         // Push entry in the queue
         self.entries.push((self.next_id, entry));
         self.next_id += 1;
+        metrics::increment_gauge!("tgi_queue_size", 1.0);
     }

     // Get the next batch
@@ -190,6 +191,8 @@ impl State {

         // Increment batch id
         self.next_batch_id += 1;
+        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
+        metrics::histogram!("tgi_batch_next_size", batch.size as f64);
         Some((batch_entries, batch, next_batch_span))
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 19af1e78..48affa46 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -12,6 +12,7 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
 use futures::Stream;
+use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use text_generation_client::ShardedClient;
@@ -57,14 +58,14 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+                    metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                     tracing::error!("{err}");
                     yield Ok(Event::from(err))
                 }
@@ -287,6 +310,17 @@ async fn generate_stream(
     Sse::new(stream).keep_alive(KeepAlive::default())
 }

+/// Prometheus metrics scrape endpoint
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/metrics",
+    responses((status = 200, description = "Prometheus Metrics", body = String))
+)]
+async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
+    prom_handle.render()
+}
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
@@ -307,6 +341,7 @@ pub async fn run(
         paths(
             generate,
             generate_stream,
+            metrics,
         ),
         components(
             schemas(
@@ -350,6 +385,12 @@ pub async fn run(
         max_concurrent_requests,
     );

+    // Prometheus handler
+    let builder = PrometheusBuilder::new();
+    let prom_handle = builder
+        .install_recorder()
+        .expect("failed to install metrics recorder");
+
     // Create router
     let app = Router::new()
         .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
@@ -359,6 +400,8 @@ pub async fn run(
         .route("/", get(health))
         .route("/health", get(health))
         .layer(Extension(infer))
+        .route("/metrics", get(metrics))
+        .layer(Extension(prom_handle))
         .layer(opentelemetry_tracing_layer());

     // Run server
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 50d090cd..e63cfb47 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -13,7 +13,7 @@ use tracing::{instrument, Span};
 #[derive(Debug, Clone)]
 pub struct Validation {
     /// Channel to communicate with the background validation task
-    sender: mpsc::Sender<ValidationRequest>,
+    sender: mpsc::UnboundedSender<ValidationRequest>,
 }

 impl Validation {
@@ -25,7 +25,7 @@ impl Validation {
         max_total_tokens: usize,
     ) -> Self {
         // Create channel
-        let (validation_sender, validation_receiver) = mpsc::channel(128);
+        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();

         // Launch background validation task
         tokio::spawn(validation_task(
@@ -54,7 +54,6 @@ impl Validation {
         // Unwrap is safe here
         self.sender
             .send((request, sender, Span::current()))
-            .await
             .unwrap();
         // Await on response channel
         // Unwrap is safe here
@@ -70,7 +69,7 @@ async fn validation_task(
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
-    mut receiver: mpsc::Receiver<ValidationRequest>,
+    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
 ) {
     let mut workers_senders = Vec::with_capacity(workers);

@@ -131,6 +130,7 @@ fn validation_worker(
                     &mut rng,
                 )
                 .map_err(|err| {
+                    metrics::increment_counter!("tgi_request_failure", "err" => "validation");
                     tracing::error!("{err}");
                     err
                 }),
@@ -214,6 +214,7 @@ fn validate(
         Ok(encoding) => {
             let input_length = encoding.len();
             let total_tokens = input_length + max_new_tokens as usize;
+
             if input_length > max_input_length {
                 Err(ValidationError::InputLength(max_input_length, input_length))
             } else if total_tokens > max_total_tokens {
@@ -237,6 +238,9 @@ fn validate(
                 stop_sequences,
             };

+            metrics::histogram!("tgi_request_input_length", input_length as f64);
+            metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+
             Ok(ValidGenerateRequest {
                 inputs: request.inputs,
                 input_length: input_length as u32,
diff --git a/server/text_generation/utils/convert.py b/server/text_generation/utils/convert.py
index 30144f0c..437e2308 100644
--- a/server/text_generation/utils/convert.py
+++ b/server/text_generation/utils/convert.py
@@ -49,6 +49,8 @@ def convert_file(pt_file: Path, st_file: Path):
     """
     Convert a pytorch file to a safetensors file
     """
+    logger.info(f"Convert {pt_file} to {st_file}.")
+
     pt_state = torch.load(pt_file, map_location="cpu")
     if "state_dict" in pt_state:
         pt_state = pt_state["state_dict"]
diff --git a/server/text_generation/utils/hub.py b/server/text_generation/utils/hub.py
index 372df306..713488b5 100644
--- a/server/text_generation/utils/hub.py
+++ b/server/text_generation/utils/hub.py
@@ -132,9 +132,9 @@ def download_weights(
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
-            return local_file
+            return Path(local_file)

-        logger.info(f"Starting {filename} download.")
+        logger.info(f"Download file: {filename}")
         start_time = time.time()
         local_file = hf_hub_download(
             filename=filename,
@@ -143,7 +143,7 @@ def download_weights(
             local_files_only=False,
         )
         logger.info(
-            f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+            f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
         )
         return Path(local_file)