From b3cc379550dc5e31a707f50e5fd86ded2f9b9a4e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 9 Feb 2023 19:24:09 +0100 Subject: [PATCH] add tracing to rust router --- Cargo.lock | 273 +++++++++++++++++++- Cargo.toml | 1 + router/Cargo.toml | 5 +- router/client/Cargo.toml | 1 + router/client/src/client.rs | 36 +-- router/client/src/sharded_client.rs | 4 + router/grpc-metadata/Cargo.toml | 10 + router/grpc-metadata/src/lib.rs | 62 +++++ router/src/infer.rs | 40 ++- router/src/main.rs | 77 +++++- router/src/queue.rs | 39 ++- router/src/server.rs | 4 +- router/src/validation.rs | 18 +- server/text_generation/cli.py | 6 +- server/text_generation/models/seq2seq_lm.py | 165 ++++++------ server/text_generation/server.py | 9 +- server/text_generation/tracing.py | 33 +-- 17 files changed, 615 insertions(+), 168 deletions(-) create mode 100644 router/grpc-metadata/Cargo.toml create mode 100644 router/grpc-metadata/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 0a537115..bfa62eb1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -424,6 +424,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "derive_builder" version = "0.9.0" @@ -717,6 +730,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" +[[package]] +name = "grpc-metadata" +version = "0.1.0" +dependencies = [ + "opentelemetry", + "tonic 0.6.2", + "tracing", + "tracing-opentelemetry", +] + [[package]] name = "h2" version = "0.3.15" @@ -1010,6 +1033,15 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata", +] + [[package]] name = "matchit" version = "0.7.0" @@ -1231,6 +1263,86 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "opentelemetry" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" +dependencies = [ + "opentelemetry_api", + "opentelemetry_sdk", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde" +dependencies = [ + "async-trait", + "futures", + "futures-util", + "http", + "opentelemetry", + "opentelemetry-proto", + "prost 0.11.6", + "thiserror", + "tokio", + "tonic 0.8.3", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28" +dependencies = [ + "futures", + "futures-util", + "opentelemetry", + "prost 0.11.6", + "tonic 0.8.3", + "tonic-build 0.8.4", +] + +[[package]] +name = "opentelemetry_api" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22" +dependencies = [ + "fnv", + "futures-channel", + "futures-util", + "indexmap", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113" +dependencies = [ + "async-trait", + "crossbeam-channel", + "dashmap", + "fnv", + "futures-channel", + "futures-executor", + "futures-util", + "once_cell", + "opentelemetry_api", + "percent-encoding", + "rand", + "thiserror", + "tokio", + "tokio-stream", +] + [[package]] name = "os_str_bytes" version = "6.3.1" @@ -1332,6 +1444,16 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "prettyplease" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e97e3215779627f01ee256d2fad52f3d95e8e1c11e9fc6fd08f7cd455d5d5c78" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1372,7 +1494,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.9.0", +] + +[[package]] +name = "prost" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698" +dependencies = [ + "bytes", + "prost-derive 0.11.6", ] [[package]] @@ -1388,13 +1520,35 @@ dependencies = [ "log", "multimap", "petgraph", - "prost", - "prost-types", + "prost 0.9.0", + "prost-types 0.9.0", "regex", "tempfile", "which", ] +[[package]] +name = "prost-build" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e" +dependencies = [ + "bytes", + "heck 0.4.0", + "itertools 0.10.5", + "lazy_static", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost 0.11.6", + "prost-types 0.11.6", + "regex", + "syn", + "tempfile", + "which", +] + [[package]] name = "prost-derive" version = "0.9.0" @@ -1408,6 +1562,19 @@ dependencies = [ "syn", ] +[[package]] +name = "prost-derive" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d" +dependencies = [ + "anyhow", + "itertools 0.10.5", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "prost-types" version = "0.9.0" @@ -1415,7 +1582,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a" dependencies = [ "bytes", - "prost", + "prost 0.9.0", +] + +[[package]] +name = "prost-types" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788" +dependencies = [ + "bytes", + "prost 0.11.6", ] [[package]] @@ -1523,6 +1700,15 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + 
"regex-syntax", +] + [[package]] name = "regex-syntax" version = "0.6.28" @@ -1891,11 +2077,12 @@ name = "text-generation-client" version = "0.2.1" dependencies = [ "futures", - "prost", + "grpc-metadata", + "prost 0.9.0", "thiserror", "tokio", - "tonic", - "tonic-build", + "tonic 0.6.2", + "tonic-build 0.6.2", "tower", "tracing", "tracing-error", @@ -1925,6 +2112,8 @@ dependencies = [ "clap 4.0.22", "futures", "nohash-hasher", + "opentelemetry", + "opentelemetry-otlp", "parking_lot", "rand", "serde", @@ -1935,6 +2124,7 @@ dependencies = [ "tokio", "tokio-stream", "tracing", + "tracing-opentelemetry", "tracing-subscriber", "utoipa", "utoipa-swagger-ui", @@ -2148,8 +2338,8 @@ dependencies = [ "hyper-timeout", "percent-encoding", "pin-project", - "prost", - "prost-derive", + "prost 0.9.0", + "prost-derive 0.9.0", "tokio", "tokio-stream", "tokio-util 0.6.10", @@ -2160,6 +2350,38 @@ dependencies = [ "tracing-futures", ] +[[package]] +name = "tonic" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64", + "bytes", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost 0.11.6", + "prost-derive 0.11.6", + "tokio", + "tokio-stream", + "tokio-util 0.7.4", + "tower", + "tower-layer", + "tower-service", + "tracing", + "tracing-futures", +] + [[package]] name = "tonic-build" version = "0.6.2" @@ -2167,7 +2389,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9403f1bafde247186684b230dc6f38b5cd514584e8bec1dd32514be4745fa757" dependencies = [ "proc-macro2", - "prost-build", + "prost-build 0.9.0", + "quote", + "syn", +] + +[[package]] +name = "tonic-build" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build 0.11.6", "quote", "syn", ] @@ -2288,6 +2523,20 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-opentelemetry" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" +dependencies = [ + "once_cell", + "opentelemetry", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", +] + [[package]] name = "tracing-serde" version = "0.1.3" @@ -2304,12 +2553,16 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "serde", "serde_json", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", "tracing-serde", diff --git a/Cargo.toml b/Cargo.toml index 3720af32..b3bd5dce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "router", "router/client", + "router/grpc-metadata", "launcher" ] diff --git a/router/Cargo.toml b/router/Cargo.toml index 186f97c5..6c35f443 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -19,6 +19,8 @@ text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } futures = "0.3.24" nohash-hasher = "0.2.0" +opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } +opentelemetry-otlp = 
"0.11.0" parking_lot = "0.12.1" rand = "0.8.5" serde = "1.0.145" @@ -28,7 +30,8 @@ tokenizers = "0.13.0" tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.11" tracing = "0.1.36" -tracing-subscriber = { version = "0.3.15", features = ["json"] } +tracing-opentelemetry = "0.18.0" +tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] } utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index 4cddf81b..25c17b2d 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] futures = "^0.3" +grpc-metadata = { path = "../grpc-metadata" } prost = "^0.9" thiserror = "^1.0" tokio = { version = "^1.21", features = ["sync"] } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 77a43110..199182f4 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -2,8 +2,9 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v1::*; use crate::Result; +use grpc_metadata::InjectTelemetryContext; use tonic::transport::{Channel, Uri}; -use tracing::*; +use tracing::instrument; /// Text Generation Inference gRPC client #[derive(Clone)] @@ -38,12 +39,8 @@ impl Client { /// Returns a list of uris or unix sockets of all shards #[instrument(skip(self))] pub async fn service_discovery(&mut self) -> Result> { - let request = tonic::Request::new(ServiceDiscoveryRequest {}); - let response = self - .stub - .service_discovery(request) - .instrument(info_span!("service_discovery")) - .await?; + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await?; let urls = response .into_inner() .urls @@ -60,11 +57,8 @@ impl Client { /// Clear the past generations cache #[instrument(skip(self))] pub async fn clear_cache(&mut self) -> Result<()> { - let request = tonic::Request::new(ClearCacheRequest {}); - self.stub - .clear_cache(request) - .instrument(info_span!("clear_cache")) - .await?; + let request = tonic::Request::new(ClearCacheRequest {}).inject_context(); + self.stub.clear_cache(request).await?; Ok(()) } @@ -74,13 +68,8 @@ impl Client { /// and the next cached batch #[instrument(skip(self))] pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }); - let response = self - .stub - .prefill(request) - .instrument(info_span!("prefill")) - .await? - .into_inner(); + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); Ok((response.generations, response.batch)) } @@ -93,13 +82,8 @@ impl Client { &mut self, batches: Vec, ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(DecodeRequest { batches }); - let response = self - .stub - .decode(request) - .instrument(info_span!("decode")) - .await? 
- .into_inner(); + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); Ok((response.generations, response.batch)) } } diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 56335f92..f77425cd 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation}; use futures::future::join_all; use futures::future::select_all; use tonic::transport::Uri; +use tracing::instrument; /// Text Generation Inference gRPC multi client pub struct ShardedClient { @@ -38,6 +39,7 @@ impl ShardedClient { } /// Clear the past generations cache + #[instrument(skip(self))] pub async fn clear_cache(&mut self) -> Result<()> { let futures: Vec<_> = self .clients @@ -51,6 +53,7 @@ impl ShardedClient { /// /// Returns Generation for each request in batch /// and the next cached batch + #[instrument(skip(self))] pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { let futures: Vec<_> = self .clients @@ -66,6 +69,7 @@ impl ShardedClient { /// /// Returns Generation for each request in batches /// and the next cached batch + #[instrument(skip(self))] pub async fn decode( &mut self, batches: Vec, diff --git a/router/grpc-metadata/Cargo.toml b/router/grpc-metadata/Cargo.toml new file mode 100644 index 00000000..7367fb26 --- /dev/null +++ b/router/grpc-metadata/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "grpc-metadata" +version = "0.1.0" +edition = "2021" + +[dependencies] +opentelemetry = "0.18.0" +tonic = "^0.6" +tracing = "^0.1" +tracing-opentelemetry = "0.18.0" diff --git a/router/grpc-metadata/src/lib.rs b/router/grpc-metadata/src/lib.rs new file mode 100644 index 00000000..cf2c7ed2 --- /dev/null +++ b/router/grpc-metadata/src/lib.rs @@ -0,0 +1,62 @@ +//! A crate to extract and inject a OpenTelemetry context from and to a gRPC request. +//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples + +use opentelemetry::global; +use opentelemetry::propagation::{Extractor, Injector}; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +/// Extract context metadata from a gRPC request's metadata +struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap); + +impl<'a> Extractor for MetadataExtractor<'a> { + /// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None + fn get(&self, key: &str) -> Option<&str> { + self.0.get(key).and_then(|metadata| metadata.to_str().ok()) + } + + /// Collect all the keys from the MetadataMap. + fn keys(&self) -> Vec<&str> { + self.0 + .keys() + .map(|key| match key { + tonic::metadata::KeyRef::Ascii(v) => v.as_str(), + tonic::metadata::KeyRef::Binary(v) => v.as_str(), + }) + .collect::>() + } +} + +/// Inject context in the metadata of a gRPC request. +struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); + +impl<'a> Injector for MetadataInjector<'a> { + /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs + fn set(&mut self, key: &str, value: String) { + if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { + if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) { + self.0.insert(key, val); + } + } + } +} + +/// Get a context from the global context and inject the span into a gRPC request's metadata. 
+fn inject(metadata: &mut tonic::metadata::MetadataMap) { + global::get_text_map_propagator(|propagator| { + propagator.inject_context( + &tracing::Span::current().context(), + &mut MetadataInjector(metadata), + ) + }) +} + +pub trait InjectTelemetryContext { + fn inject_context(self) -> Self; +} + +impl InjectTelemetryContext for tonic::Request { + fn inject_context(mut self) -> Self { + inject(self.metadata_mut()); + self + } +} diff --git a/router/src/infer.rs b/router/src/infer.rs index 159b7ca7..bb4aa01e 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -3,6 +3,7 @@ use crate::validation::{Validation, ValidationError}; use crate::GenerateRequest; use crate::{Entry, Queue, Token}; use nohash_hasher::IntMap; +use opentelemetry::trace::TraceContextExt; use std::future::Future; use std::sync::Arc; use text_generation_client::{ @@ -13,7 +14,8 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; -use tracing::instrument; +use tracing::{info_span, instrument, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; /// Inference struct #[derive(Clone)] @@ -69,6 +71,7 @@ impl Infer { } /// Add a new request to the queue and return a stream of InferStreamResponse + #[instrument(skip(self))] pub(crate) async fn generate_stream( &self, request: GenerateRequest, @@ -87,6 +90,8 @@ impl Infer { self.queue.append(Entry { request: valid_request, response_tx, + parent_span: info_span!("entry"), + batch_span: None, time: Instant::now(), batch_time: None, _permit: permit, @@ -101,6 +106,7 @@ impl Infer { } /// Add a new request to the queue and return a InferResponse + #[instrument(skip(self))] pub(crate) async fn generate( &self, request: GenerateRequest, @@ -169,7 +175,6 @@ impl Infer { /// Will be launched in a background Tokio task /// /// Batches requests and sends them to the inference server -#[instrument(skip(client, queue, shared))] async fn batching_task( mut client: ShardedClient, max_batch_size: usize, @@ -188,8 +193,10 @@ async fn batching_task( // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the queue - while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await { - let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await; + while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await { + let mut cached_batch = wrap_future(client.prefill(batch), &mut entries) + .instrument(span) + .await; let mut waiting_tokens = 1; // We loop until we do not receive any cached batch from the inference server (== until @@ -210,13 +217,15 @@ async fn batching_task( }; // Try to get a new batch - if let Some((mut new_entries, new_batch)) = queue + if let Some((mut new_entries, new_batch, span)) = queue .next_batch(min_size, max_batch_size - batch_size as usize) .await { // Generate one token for this new batch to have the attention past in cache let new_cached_batch = - wrap_future(client.prefill(new_batch), &mut new_entries).await; + wrap_future(client.prefill(new_batch), &mut new_entries) + .instrument(span) + .await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch @@ -226,8 +235,20 @@ async fn batching_task( } } } + let next_batch_span = info_span!("batch"); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span for this entry/batch tuple + let 
entry_batch_span = info_span!(parent: &entry.parent_span, "infer"); + // Add link to span + entry_batch_span + .add_link(next_batch_span.context().span().span_context().clone()); + // Update entry + entry.batch_span = Some(entry_batch_span); + }); - cached_batch = wrap_future(client.decode(batches), &mut entries).await; + cached_batch = wrap_future(client.decode(batches), &mut entries) + .instrument(next_batch_span) + .await; waiting_tokens += 1; } } @@ -235,6 +256,7 @@ async fn batching_task( } /// Wrap a future inside a match statement to handle errors and send the responses to Infer +#[instrument(skip(future))] async fn wrap_future( future: impl Future, Option), ClientError>>, entries: &mut IntMap, @@ -253,6 +275,7 @@ async fn wrap_future( } /// Send errors to Infer for all `entries` +#[instrument] fn send_error(error: ClientError, entries: &mut IntMap) { entries.drain().for_each(|(_, entry)| { // unwrap_or is valid here as we don't care if the receiver is gone. @@ -264,6 +287,7 @@ fn send_error(error: ClientError, entries: &mut IntMap) { } /// Send one or multiple `InferStreamResponse` to Infer for all `entries` +#[instrument] fn send_generations(generations: Vec, entries: &mut IntMap) { generations.into_iter().for_each(|generation| { // Get entry @@ -272,6 +296,8 @@ fn send_generations(generations: Vec, entries: &mut IntMap, } fn main() -> Result<(), std::io::Error> { @@ -43,14 +53,9 @@ fn main() -> Result<(), std::io::Error> { tokenizer_name, validation_workers, json_output, + otlp_endpoint, } = args; - if json_output { - tracing_subscriber::fmt().json().init(); - } else { - tracing_subscriber::fmt().compact().init(); - } - if validation_workers == 0 { panic!("validation_workers must be > 0"); } @@ -67,6 +72,8 @@ fn main() -> Result<(), std::io::Error> { .build() .unwrap() .block_on(async { + init_logging(otlp_endpoint, json_output); + // Instantiate sharded client from the master unix socket let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) .await @@ -96,3 +103,57 @@ fn main() -> Result<(), std::io::Error> { Ok(()) }) } + +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +fn init_logging(otlp_endpoint: Option, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer + .json() + .flatten_event(true) + .with_span_list(false) + .boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = + opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config(trace::config().with_resource(Resource::new(vec![ + KeyValue::new("service.name", "text-generation-inference.router"), + ]))) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + }; + } + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); + + 
tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} diff --git a/router/src/queue.rs b/router/src/queue.rs index 2aaf93b1..cd9bb450 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -2,11 +2,14 @@ use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; +use opentelemetry::trace::TraceContextExt; use std::cmp::min; use text_generation_client::{Batch, Request}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; use tokio::time::Instant; +use tracing::{info_span, instrument, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; /// Queue entry #[derive(Debug)] @@ -15,6 +18,10 @@ pub(crate) struct Entry { pub request: ValidGenerateRequest, /// Response sender to communicate between the Infer struct and the batching_task pub response_tx: UnboundedSender>, + /// Request Span + pub parent_span: Span, + /// Batch Span + pub batch_span: Option, /// Instant when this entry was created pub time: Instant, /// Instant when this entry was added to a batch @@ -42,13 +49,17 @@ impl Queue { } /// Append an entry to the queue + #[instrument(skip(self))] pub(crate) fn append(&self, entry: Entry) { // Send append command to the background task managing the state // Unwrap is safe here - self.queue_sender.send(QueueCommand::Append(entry)).unwrap(); + self.queue_sender + .send(QueueCommand::Append(entry, Span::current())) + .unwrap(); } // Get the next batch + #[instrument(skip(self))] pub(crate) async fn next_batch( &self, min_size: Option, @@ -63,6 +74,7 @@ impl Queue { min_size, max_size, response_sender, + span: Span::current(), }) .unwrap(); // Await on response channel @@ -77,15 +89,16 @@ async fn queue_task(mut receiver: UnboundedReceiver) { while let Some(cmd) = receiver.recv().await { match cmd { - QueueCommand::Append(entry) => state.append(entry), + QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)), QueueCommand::NextBatch { min_size, max_size, response_sender, - } => { + span, + } => span.in_scope(|| { let next_batch = state.next_batch(min_size, max_size); response_sender.send(next_batch).unwrap_or(()); - } + }), } } } @@ -131,6 +144,7 @@ impl State { } } + let next_batch_span = info_span!("batch"); let next_batch_size = min(self.entries.len(), max_size); let mut batch_requests = Vec::with_capacity(next_batch_size); @@ -141,6 +155,13 @@ impl State { self.entries .drain(..next_batch_size) .for_each(|(id, mut entry)| { + // Create a new span for this entry/batch tuple + let entry_batch_span = info_span!(parent: &entry.parent_span, "infer"); + // Add link to span + entry_batch_span.add_link(next_batch_span.context().span().span_context().clone()); + // Update entry + entry.batch_span = Some(entry_batch_span); + batch_requests.push(Request { id, inputs: entry.request.inputs.clone(), @@ -162,19 +183,20 @@ impl State { // Increment batch id self.next_batch_id += 1; - Some((batch_entries, batch)) + Some((batch_entries, batch, next_batch_span)) } } -type NextBatch = (IntMap, Batch); +type NextBatch = (IntMap, Batch, Span); #[derive(Debug)] enum QueueCommand { - Append(Entry), + Append(Entry, Span), NextBatch { min_size: Option, max_size: usize, response_sender: oneshot::Sender>, + span: Span, }, } @@ -184,6 +206,7 @@ mod tests { use std::sync::Arc; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use tokio::sync::{mpsc, Semaphore}; + 
use tracing::info_span; fn default_entry() -> Entry { let semaphore = Arc::new(Semaphore::new(1)); @@ -208,6 +231,8 @@ mod tests { }, }, response_tx, + parent_span: info_span!("entry"), + batch_span: None, time: Instant::now(), batch_time: None, _permit: permit, diff --git a/router/src/server.rs b/router/src/server.rs index dffdf155..628911ca 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -18,7 +18,7 @@ use tokenizers::Tokenizer; use tokio::signal; use tokio::time::Instant; use tokio_stream::StreamExt; -use tracing::instrument; +use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -197,7 +197,7 @@ async fn generate_stream( let mut error = false; let details = req.0.parameters.details; - match infer.generate_stream(req.0).await { + match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { Ok(mut response_stream) => { // Server-Sent Event stream while let Some(response) = response_stream.next().await { diff --git a/router/src/validation.rs b/router/src/validation.rs index 3cca48e9..42eb9b12 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; +use tracing::{instrument, Span}; const MAX_MAX_NEW_TOKENS: u32 = 512; const MAX_STOP_SEQUENCES: usize = 4; @@ -36,6 +37,7 @@ impl Validation { } /// Validate a payload and get the number of tokens in the input + #[instrument(skip(self))] pub(crate) async fn validate( &self, request: GenerateRequest, @@ -44,7 +46,10 @@ impl Validation { let (sender, receiver) = oneshot::channel(); // Send request to the background validation task // Unwrap is safe here - self.sender.send((request, sender)).await.unwrap(); + self.sender + .send((request, sender, Span::current())) + .await + .unwrap(); // Await on response channel // Unwrap is safe here receiver.await.unwrap() @@ -97,10 +102,12 @@ fn validation_worker( let mut rng = rand::thread_rng(); // Loop over requests - while let Some((request, response_tx)) = receiver.blocking_recv() { - response_tx - .send(validate(request, &tokenizer, max_input_length, &mut rng)) - .unwrap_or(()) + while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() { + parent_span.in_scope(|| { + response_tx + .send(validate(request, &tokenizer, max_input_length, &mut rng)) + .unwrap_or(()) + }) } } @@ -203,6 +210,7 @@ fn validate( type ValidationRequest = ( GenerateRequest, oneshot::Sender>, + Span, ); #[derive(Debug)] diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py index 47aebd16..ba3a6a23 100644 --- a/server/text_generation/cli.py +++ b/server/text_generation/cli.py @@ -4,6 +4,7 @@ import typer from pathlib import Path from loguru import logger +from typer import Argument from typing import Optional from text_generation import server, utils @@ -19,9 +20,9 @@ def serve( sharded: bool = False, quantize: bool = False, uds_path: Path = "/tmp/text-generation", - otlp_endpoint: Optional[str] = None, logger_level: str = "INFO", json_output: bool = False, + otlp_endpoint: Optional[str] = Argument(None, envvar="OTLP_ENDPOINT"), ): if sharded: assert ( @@ -49,7 +50,8 @@ def serve( diagnose=False, ) # Setup OpenTelemetry distributed tracing - setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) + if otlp_endpoint is not None: + setup_tracing(shard=os.getenv("RANK", 0), 
otlp_endpoint=otlp_endpoint) server.serve(model_id, revision, sharded, quantize, uds_path) diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 80aecbac..96dfeb67 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -1,6 +1,7 @@ import torch from dataclasses import dataclass +from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type @@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling +tracer = trace.get_tracer(__name__) + @dataclass class Seq2SeqLMBatch(Batch): @@ -107,6 +110,7 @@ class Seq2SeqLMBatch(Batch): ) @classmethod + @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": """Concatenate multiple batches together by padding internal torch tensors""" @@ -324,6 +328,7 @@ class Seq2SeqLM(Model): def decode(self, decoder_ids: List[int]) -> str: return self.tokenizer.decode(decoder_ids, skip_special_tokens=True) + @tracer.start_as_current_span("forward") def forward( self, input_ids, @@ -361,6 +366,7 @@ class Seq2SeqLM(Model): outputs.past_key_values, ) + @tracer.start_as_current_span("generate_token") def generate_token( self, batch: Seq2SeqLMBatch ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: @@ -401,91 +407,94 @@ class Seq2SeqLM(Model): batch.decoder_input_ids, ) - # For each member of the batch - for i, ( - request, - input_length, - decoder_input_length, - logits, - next_token_chooser, - stopping_criteria, - input_tokens, - decoder_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - decoder_input_ids.view(1, -1), logits - ) - - # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) - new_decoder_input_length = decoder_input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.tokenizer.decode( - next_token_id_squeezed, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria(next_token_id, next_token_text) - - if stop: - # Slice with decoder_input_length to remove padding - # Decode all tokens - output_text = self.decode(decoder_input_ids[-new_decoder_input_length:]) - - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - # Keep request in the batch - generated_text = None - next_batch_keep_indices.append(i) - next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) - next_batch_size += 1 - next_batch_input_lengths.append(input_length) - next_batch_decoder_input_lengths.append(new_decoder_input_length) - next_batch_max_input_length = max( - next_batch_max_input_length, input_length - ) - next_batch_max_decoder_input_length = max( - next_batch_max_decoder_input_length, new_decoder_input_length + with tracer.start_as_current_span("post_processing"): + # For each member of the batch + for i, ( + request, + input_length, + 
decoder_input_length, + logits, + next_token_chooser, + stopping_criteria, + input_tokens, + decoder_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + decoder_input_ids.view(1, -1), logits ) - # Prefill - if stopping_criteria.current_tokens == 1: - prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + # Append next token to decoder tokens + decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) + new_decoder_input_length = decoder_input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text = self.tokenizer.decode( + next_token_id_squeezed, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = PrefillTokens( - prefill_token_ids, [float("nan")], prefill_texts + + # Evaluate stopping criteria + stop, reason = stopping_criteria(next_token_id, next_token_text) + + if stop: + # Slice with decoder_input_length to remove padding + # Decode all tokens + output_text = self.decode( + decoder_input_ids[-new_decoder_input_length:] + ) + + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + # Keep request in the batch + generated_text = None + next_batch_keep_indices.append(i) + next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) + next_batch_size += 1 + next_batch_input_lengths.append(input_length) + next_batch_decoder_input_lengths.append(new_decoder_input_length) + next_batch_max_input_length = max( + next_batch_max_input_length, input_length + ) + next_batch_max_decoder_input_length = max( + next_batch_max_decoder_input_length, new_decoder_input_length + ) + + # Prefill + if stopping_criteria.current_tokens == 1: + prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, [float("nan")], prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + generated_text, ) - else: - prefill_tokens = None - generation = Generation( - request.id, - prefill_tokens, - next_token_id_squeezed, - next_token_logprob, - next_token_text, - generated_text, - ) - - generations.append(generation) + generations.append(generation) # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: diff --git a/server/text_generation/server.py b/server/text_generation/server.py index f9e10b87..f3129cb4 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -13,7 +13,7 @@ from text_generation.cache import Cache from text_generation.interceptor import ExceptionInterceptor from text_generation.models import Model, get_model from text_generation.pb import generate_pb2_grpc, generate_pb2 -from text_generation.tracing import OpenTelemetryAioServerInterceptorUnix +from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): @@ -101,7 +101,12 @@ def serve( logger.exception("Error when 
initializing model") raise - server = aio.server(interceptors=[ExceptionInterceptor(), OpenTelemetryAioServerInterceptorUnix()]) + server = aio.server( + interceptors=[ + ExceptionInterceptor(), + UDSOpenTelemetryAioServerInterceptor(), + ] + ) generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( TextGenerationService(model, Cache(), server_urls), server ) diff --git a/server/text_generation/tracing.py b/server/text_generation/tracing.py index d58f29b1..fc90a8ae 100644 --- a/server/text_generation/tracing.py +++ b/server/text_generation/tracing.py @@ -2,7 +2,9 @@ import grpc from opentelemetry import trace from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter -from opentelemetry.instrumentation.grpc._aio_server import OpenTelemetryAioServerInterceptor +from opentelemetry.instrumentation.grpc._aio_server import ( + OpenTelemetryAioServerInterceptor, +) from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider @@ -14,15 +16,13 @@ from opentelemetry.sdk.trace.export import ( from typing import Optional -class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor): +class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor): def __init__(self): super().__init__(trace.get_tracer(__name__)) - def _start_span( - self, handler_call_details, context, set_status_on_exception=False - ): + def _start_span(self, handler_call_details, context, set_status_on_exception=False): """ - Rewrite _start_span method to support Unix socket gRPC context + Rewrite _start_span method to support Unix Domain Socket gRPC contexts """ # standard attributes @@ -33,9 +33,7 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor): # if we have details about the call, split into service and method if handler_call_details.method: - service, method = handler_call_details.method.lstrip("/").split( - "/", 1 - ) + service, method = handler_call_details.method.lstrip("/").split("/", 1) attributes.update( { SpanAttributes.RPC_METHOD: method, @@ -59,17 +57,12 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor): ) -def setup_tracing(shard: int, otlp_endpoint: Optional[str]): - resource = Resource.create(attributes={"service.name": f"text-generation-server.{shard}"}) +def setup_tracing(shard: int, otlp_endpoint: str): + resource = Resource.create( + attributes={"service.name": f"text-generation-inference.server-{shard}"} + ) + span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) + span_processor = BatchSpanProcessor(span_exporter) trace.set_tracer_provider(TracerProvider(resource=resource)) - - if otlp_endpoint is None: - # span_exporter = ConsoleSpanExporter(out=open(os.devnull, "w")) - span_exporter = ConsoleSpanExporter() - span_processor = SimpleSpanProcessor(span_exporter) - else: - span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) - span_processor = BatchSpanProcessor(span_exporter) - trace.get_tracer_provider().add_span_processor(span_processor)