add tracing to rust router

OlivierDehaene 2023-02-09 19:24:09 +01:00
parent 04015dfa90
commit b3cc379550
17 changed files with 615 additions and 168 deletions

Cargo.lock (generated, 273 lines changed)
View File

@@ -424,6 +424,19 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "derive_builder"
version = "0.9.0"
@@ -717,6 +730,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "grpc-metadata"
version = "0.1.0"
dependencies = [
"opentelemetry",
"tonic 0.6.2",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "h2"
version = "0.3.15"
@@ -1010,6 +1033,15 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.7.0"
@@ -1231,6 +1263,86 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "opentelemetry"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde"
dependencies = [
"async-trait",
"futures",
"futures-util",
"http",
"opentelemetry",
"opentelemetry-proto",
"prost 0.11.6",
"thiserror",
"tokio",
"tonic 0.8.3",
]
[[package]]
name = "opentelemetry-proto"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28"
dependencies = [
"futures",
"futures-util",
"opentelemetry",
"prost 0.11.6",
"tonic 0.8.3",
"tonic-build 0.8.4",
]
[[package]]
name = "opentelemetry_api"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22"
dependencies = [
"fnv",
"futures-channel",
"futures-util",
"indexmap",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
]
[[package]]
name = "opentelemetry_sdk"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113"
dependencies = [
"async-trait",
"crossbeam-channel",
"dashmap",
"fnv",
"futures-channel",
"futures-executor",
"futures-util",
"once_cell",
"opentelemetry_api",
"percent-encoding",
"rand",
"thiserror",
"tokio",
"tokio-stream",
]
[[package]]
name = "os_str_bytes"
version = "6.3.1"
@@ -1332,6 +1444,16 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e97e3215779627f01ee256d2fad52f3d95e8e1c11e9fc6fd08f7cd455d5d5c78"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
@@ -1372,7 +1494,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001"
dependencies = [
"bytes",
-"prost-derive",
+"prost-derive 0.9.0",
]
[[package]]
name = "prost"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698"
dependencies = [
"bytes",
"prost-derive 0.11.6",
]
[[package]]
@@ -1388,13 +1520,35 @@ dependencies = [
"log",
"multimap",
"petgraph",
-"prost",
+"prost 0.9.0",
-"prost-types",
+"prost-types 0.9.0",
"regex",
"tempfile",
"which",
]
[[package]]
name = "prost-build"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e"
dependencies = [
"bytes",
"heck 0.4.0",
"itertools 0.10.5",
"lazy_static",
"log",
"multimap",
"petgraph",
"prettyplease",
"prost 0.11.6",
"prost-types 0.11.6",
"regex",
"syn",
"tempfile",
"which",
]
[[package]]
name = "prost-derive"
version = "0.9.0"
@@ -1408,6 +1562,19 @@ dependencies = [
"syn",
]
[[package]]
name = "prost-derive"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d"
dependencies = [
"anyhow",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "prost-types"
version = "0.9.0"
@@ -1415,7 +1582,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a"
dependencies = [
"bytes",
-"prost",
+"prost 0.9.0",
]
[[package]]
name = "prost-types"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788"
dependencies = [
"bytes",
"prost 0.11.6",
]
[[package]]
@@ -1523,6 +1700,15 @@ dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.28"
@@ -1891,11 +2077,12 @@ name = "text-generation-client"
version = "0.2.1"
dependencies = [
"futures",
-"prost",
+"grpc-metadata",
+"prost 0.9.0",
"thiserror",
"tokio",
-"tonic",
+"tonic 0.6.2",
-"tonic-build",
+"tonic-build 0.6.2",
"tower",
"tracing",
"tracing-error",
@@ -1925,6 +2112,8 @@ dependencies = [
"clap 4.0.22",
"futures",
"nohash-hasher",
"opentelemetry",
"opentelemetry-otlp",
"parking_lot",
"rand",
"serde",
@@ -1935,6 +2124,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
@@ -2148,8 +2338,8 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
-"prost",
+"prost 0.9.0",
-"prost-derive",
+"prost-derive 0.9.0",
"tokio",
"tokio-stream",
"tokio-util 0.6.10",
@@ -2160,6 +2350,38 @@ dependencies = [
"tracing-futures",
]
[[package]]
name = "tonic"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64",
"bytes",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost 0.11.6",
"prost-derive 0.11.6",
"tokio",
"tokio-stream",
"tokio-util 0.7.4",
"tower",
"tower-layer",
"tower-service",
"tracing",
"tracing-futures",
]
[[package]]
name = "tonic-build"
version = "0.6.2"
@@ -2167,7 +2389,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9403f1bafde247186684b230dc6f38b5cd514584e8bec1dd32514be4745fa757"
dependencies = [
"proc-macro2",
-"prost-build",
+"prost-build 0.9.0",
"quote",
"syn",
]
[[package]]
name = "tonic-build"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4"
dependencies = [
"prettyplease",
"proc-macro2",
"prost-build 0.11.6",
"quote",
"syn",
]
@@ -2288,6 +2523,20 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de"
dependencies = [
"once_cell",
"opentelemetry",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
]
[[package]]
name = "tracing-serde"
version = "0.1.3"
@@ -2304,12 +2553,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
"tracing-serde",

View File

@@ -2,6 +2,7 @@
members = [
"router",
"router/client",
"router/grpc-metadata",
"launcher"
]

View File

@@ -19,6 +19,8 @@ text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145"
@@ -28,7 +30,8 @@ tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tracing = "0.1.36"
-tracing-subscriber = { version = "0.3.15", features = ["json"] }
+tracing-opentelemetry = "0.18.0"
+tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.9"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }

View File

@@ -2,8 +2,9 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
-use tracing::*;
+use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Clone)]
@@ -38,12 +39,8 @@ impl Client {
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
-let request = tonic::Request::new(ServiceDiscoveryRequest {});
+let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
-let response = self
-.stub
-.service_discovery(request)
-.instrument(info_span!("service_discovery"))
-.await?;
+let response = self.stub.service_discovery(request).await?;
let urls = response
.into_inner()
.urls
@@ -60,11 +57,8 @@ impl Client {
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
-let request = tonic::Request::new(ClearCacheRequest {});
+let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
-self.stub
-.clear_cache(request)
-.instrument(info_span!("clear_cache"))
-.await?;
+self.stub.clear_cache(request).await?;
Ok(())
}
@@ -74,13 +68,8 @@ impl Client {
/// and the next cached batch
#[instrument(skip(self))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
-let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
+let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
-let response = self
-.stub
-.prefill(request)
-.instrument(info_span!("prefill"))
-.await?
-.into_inner();
+let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
}
@@ -93,13 +82,8 @@ impl Client {
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
-let request = tonic::Request::new(DecodeRequest { batches });
+let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
-let response = self
-.stub
-.decode(request)
-.instrument(info_span!("decode"))
-.await?
-.into_inner();
+let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
}
}

View File

@@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation};
use futures::future::join_all;
use futures::future::select_all;
use tonic::transport::Uri;
use tracing::instrument;
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
@@ -38,6 +39,7 @@ impl ShardedClient {
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let futures: Vec<_> = self
.clients
@@ -51,6 +53,7 @@
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip(self))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
@@ -66,6 +69,7 @@
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip(self))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,

View File

@@ -0,0 +1,10 @@
[package]
name = "grpc-metadata"
version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.6"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"

View File

@@ -0,0 +1,62 @@
//! A crate to extract and inject a OpenTelemetry context from and to a gRPC request.
//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples
use opentelemetry::global;
use opentelemetry::propagation::{Extractor, Injector};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Extract context metadata from a gRPC request's metadata
struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap);
impl<'a> Extractor for MetadataExtractor<'a> {
/// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|metadata| metadata.to_str().ok())
}
/// Collect all the keys from the MetadataMap.
fn keys(&self) -> Vec<&str> {
self.0
.keys()
.map(|key| match key {
tonic::metadata::KeyRef::Ascii(v) => v.as_str(),
tonic::metadata::KeyRef::Binary(v) => v.as_str(),
})
.collect::<Vec<_>>()
}
}
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) {
self.0.insert(key, val);
}
}
}
}
/// Get a context from the global context and inject the span into a gRPC request's metadata.
fn inject(metadata: &mut tonic::metadata::MetadataMap) {
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::Span::current().context(),
&mut MetadataInjector(metadata),
)
})
}
pub trait InjectTelemetryContext {
fn inject_context(self) -> Self;
}
impl<T> InjectTelemetryContext for tonic::Request<T> {
fn inject_context(mut self) -> Self {
inject(self.metadata_mut());
self
}
}
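For orientation, here is a minimal sketch (not part of the commit) of how a caller is expected to use the `InjectTelemetryContext` trait defined above. It assumes the same `grpc-metadata`, `opentelemetry` 0.18 and `tonic` dependencies as this crate; the empty `()` request body is a stand-in for a real generated protobuf message:

```rust
use grpc_metadata::InjectTelemetryContext;
use opentelemetry::{global, sdk::propagation::TraceContextPropagator};

fn main() {
    // A text-map propagator must be registered once per process (the router's
    // main.rs does this in init_logging); without it inject_context() has
    // nothing to write into the metadata.
    global::set_text_map_propagator(TraceContextPropagator::new());

    // Wrap any tonic request before sending it: the current tracing span's
    // OpenTelemetry context is copied into the request metadata (for the W3C
    // propagator, a `traceparent` entry), so the server can join the trace.
    let request = tonic::Request::new(()).inject_context();
    println!("outgoing metadata: {:?}", request.metadata());
}
```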

View File

@@ -3,6 +3,7 @@ use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use nohash_hasher::IntMap;
use opentelemetry::trace::TraceContextExt;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
@@ -13,7 +14,8 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
-use tracing::instrument;
+use tracing::{info_span, instrument, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inference struct
#[derive(Clone)]
@@ -69,6 +71,7 @@ impl Infer {
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
@@ -87,6 +90,8 @@ impl Infer {
self.queue.append(Entry {
request: valid_request,
response_tx,
parent_span: info_span!("entry"),
batch_span: None,
time: Instant::now(),
batch_time: None,
_permit: permit,
@@ -101,6 +106,7 @@ impl Infer {
}
/// Add a new request to the queue and return a InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
&self,
request: GenerateRequest,
@@ -169,7 +175,6 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
-#[instrument(skip(client, queue, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
@@ -188,8 +193,10 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
-while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
+while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
+let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
+.instrument(span)
+.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
@@ -210,13 +217,15 @@
};
// Try to get a new batch
-if let Some((mut new_entries, new_batch)) = queue
+if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
-wrap_future(client.prefill(new_batch), &mut new_entries).await;
+wrap_future(client.prefill(new_batch), &mut new_entries)
+.instrument(span)
+.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
@ -226,8 +235,20 @@ async fn batching_task(
}
}
}
let next_batch_span = info_span!("batch");
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span for this entry/batch tuple
let entry_batch_span = info_span!(parent: &entry.parent_span, "infer");
// Add link to span
entry_batch_span
.add_link(next_batch_span.context().span().span_context().clone());
// Update entry
entry.batch_span = Some(entry_batch_span);
});
-cached_batch = wrap_future(client.decode(batches), &mut entries).await;
+cached_batch = wrap_future(client.decode(batches), &mut entries)
+.instrument(next_batch_span)
+.await;
waiting_tokens += 1;
}
}
@ -235,6 +256,7 @@ async fn batching_task(
}
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip(future))]
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
@ -253,6 +275,7 @@ async fn wrap_future(
}
/// Send errors to Infer for all `entries`
#[instrument]
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// unwrap_or is valid here as we don't care if the receiver is gone.
@ -264,6 +287,7 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
@@ -272,6 +296,8 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
.get(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation").entered();
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.

View File

@@ -1,9 +1,17 @@
-use clap::Parser;
-use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Text Generation Inference webserver entrypoint
+use clap::Parser;
+use opentelemetry::sdk::propagation::TraceContextPropagator;
+use opentelemetry::sdk::trace;
+use opentelemetry::sdk::Resource;
+use opentelemetry::{global, KeyValue};
+use opentelemetry_otlp::WithExportConfig;
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
+use tracing_subscriber::layer::SubscriberExt;
+use tracing_subscriber::util::SubscriberInitExt;
+use tracing_subscriber::{EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
@@ -27,6 +35,8 @@ struct Args {
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
}
fn main() -> Result<(), std::io::Error> {
@ -43,14 +53,9 @@ fn main() -> Result<(), std::io::Error> {
tokenizer_name,
validation_workers,
json_output,
otlp_endpoint,
} = args;
-if json_output {
-tracing_subscriber::fmt().json().init();
-} else {
-tracing_subscriber::fmt().compact().init();
-}
if validation_workers == 0 {
panic!("validation_workers must be > 0");
}
@ -67,6 +72,8 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
init_logging(otlp_endpoint, json_output);
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
@ -96,3 +103,57 @@ fn main() -> Result<(), std::io::Error> {
Ok(())
})
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer
.json()
.flatten_event(true)
.with_span_list(false)
.boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer =
opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(trace::config().with_resource(Resource::new(vec![
KeyValue::new("service.name", "text-generation-inference.router"),
])))
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
};
}
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
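As a side note, the LOG_LEVEL handling above follows the usual tracing-subscriber layering and fallback pattern. A stripped-down sketch (fmt layer only, no OTLP exporter, not part of the commit, assuming tracing-subscriber with the env-filter feature):

```rust
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};

fn main() {
    // Read the filter from LOG_LEVEL (e.g. LOG_LEVEL=info,text_generation_router=debug)
    // and fall back to "info" when the variable is unset or cannot be parsed.
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    // Compose the filter with an stdout formatting layer and install the subscriber.
    tracing_subscriber::registry()
        .with(env_filter)
        .with(tracing_subscriber::fmt::layer())
        .init();

    tracing::info!("logging initialised");
}
```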

View File

@@ -2,11 +2,14 @@ use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use opentelemetry::trace::TraceContextExt;
use std::cmp::min;
use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Queue entry
#[derive(Debug)]
@@ -15,6 +18,10 @@ pub(crate) struct Entry {
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Request Span
pub parent_span: Span,
/// Batch Span
pub batch_span: Option<Span>,
/// Instant when this entry was created
pub time: Instant,
/// Instant when this entry was added to a batch
@@ -42,13 +49,17 @@ impl Queue {
}
/// Append an entry to the queue
#[instrument(skip(self))]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
// Unwrap is safe here
-self.queue_sender.send(QueueCommand::Append(entry)).unwrap();
+self.queue_sender
+.send(QueueCommand::Append(entry, Span::current()))
+.unwrap();
}
// Get the next batch
#[instrument(skip(self))]
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
@@ -63,6 +74,7 @@ impl Queue {
min_size,
max_size,
response_sender,
span: Span::current(),
})
.unwrap();
// Await on response channel
@@ -77,15 +89,16 @@ async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
while let Some(cmd) = receiver.recv().await {
match cmd {
-QueueCommand::Append(entry) => state.append(entry),
+QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch {
min_size,
max_size,
response_sender,
-} => {
+span,
+} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size);
response_sender.send(next_batch).unwrap_or(());
-}
+}),
}
}
}
@@ -131,6 +144,7 @@ impl State {
}
}
let next_batch_span = info_span!("batch");
let next_batch_size = min(self.entries.len(), max_size);
let mut batch_requests = Vec::with_capacity(next_batch_size);
@@ -141,6 +155,13 @@ impl State {
self.entries
.drain(..next_batch_size)
.for_each(|(id, mut entry)| {
// Create a new span for this entry/batch tuple
let entry_batch_span = info_span!(parent: &entry.parent_span, "infer");
// Add link to span
entry_batch_span.add_link(next_batch_span.context().span().span_context().clone());
// Update entry
entry.batch_span = Some(entry_batch_span);
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
@@ -162,19 +183,20 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
-Some((batch_entries, batch))
+Some((batch_entries, batch, next_batch_span))
}
}
-type NextBatch = (IntMap<u64, Entry>, Batch);
+type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
enum QueueCommand {
-Append(Entry),
+Append(Entry, Span),
NextBatch {
min_size: Option<usize>,
max_size: usize,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
},
}
@ -184,6 +206,7 @@ mod tests {
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::{mpsc, Semaphore};
use tracing::info_span;
fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
@ -208,6 +231,8 @@ mod tests {
},
},
response_tx,
parent_span: info_span!("entry"),
batch_span: None,
time: Instant::now(),
batch_time: None,
_permit: permit,
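The parent/link arrangement used in next_batch() above can be hard to picture from the diff alone. The following standalone sketch (not part of the commit, assuming the same tracing, tracing-opentelemetry 0.18 and opentelemetry 0.18 versions as the router) shows the same idea in isolation:

```rust
use opentelemetry::trace::TraceContextExt;
use tracing::info_span;
use tracing_opentelemetry::OpenTelemetrySpanExt;

fn main() {
    // Each request owns a long-lived "entry" span; every batch gets its own
    // "batch" span. The per-request work done inside a batch is recorded in an
    // "infer" span that is parented to the request and *linked* to the batch,
    // so both views of the same work can be correlated in a trace backend.
    let entry_span = info_span!("entry");
    let batch_span = info_span!("batch");

    let infer_span = info_span!(parent: &entry_span, "infer");
    infer_span.add_link(batch_span.context().span().span_context().clone());

    infer_span.in_scope(|| {
        tracing::info!("one decode step for this request, inside the batch");
    });
}
```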

View File

@@ -18,7 +18,7 @@ use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
-use tracing::instrument;
+use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
@@ -197,7 +197,7 @@ async fn generate_stream(
let mut error = false;
let details = req.0.parameters.details;
-match infer.generate_stream(req.0).await {
+match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
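For context, `.instrument(...)` as used above attaches a span to a future so that everything the future does, across all await points, is recorded under that span. A minimal sketch (not part of the commit, assumes tokio with the rt and macros features):

```rust
use tracing::{info_span, Instrument};

#[tokio::main]
async fn main() {
    // All events emitted inside this async block are attributed to "async_stream".
    async {
        tracing::info!("sending Server-Sent Events to the client");
    }
    .instrument(info_span!("async_stream"))
    .await;
}
```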

View File

@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};
const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
@@ -36,6 +37,7 @@ impl Validation {
}
/// Validate a payload and get the number of tokens in the input
#[instrument(skip(self))]
pub(crate) async fn validate(
&self,
request: GenerateRequest,
@@ -44,7 +46,10 @@ impl Validation {
let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
-self.sender.send((request, sender)).await.unwrap();
+self.sender
+.send((request, sender, Span::current()))
+.await
+.unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap()
@@ -97,10 +102,12 @@ fn validation_worker(
let mut rng = rand::thread_rng();
// Loop over requests
-while let Some((request, response_tx)) = receiver.blocking_recv() {
+while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
response_tx
.send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(())
})
}
}
@@ -203,6 +210,7 @@ fn validate(
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
Span,
);
#[derive(Debug)]
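The validation channel now carries the caller's span alongside the request. A small self-contained sketch of that pattern (plain std channel and thread, not part of the commit, only the tracing crate assumed):

```rust
use tracing::{info_span, Span};

fn main() {
    // Capture Span::current() when the message is sent, and re-enter it with
    // in_scope() on the worker side, so the blocking validation work is
    // attributed to the request that triggered it.
    let (tx, rx) = std::sync::mpsc::channel::<(String, Span)>();

    let worker = std::thread::spawn(move || {
        while let Ok((input, parent_span)) = rx.recv() {
            parent_span.in_scope(|| {
                tracing::info!(%input, "validating request");
            });
        }
    });

    let request_span = info_span!("handle_request");
    request_span.in_scope(|| {
        tx.send(("hello world".to_string(), Span::current())).unwrap();
    });

    drop(tx);
    worker.join().unwrap();
}
```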

View File

@@ -4,6 +4,7 @@ import typer
from pathlib import Path
from loguru import logger
from typer import Argument
from typing import Optional
from text_generation import server, utils
@@ -19,9 +20,9 @@ def serve(
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
-otlp_endpoint: Optional[str] = None,
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = Argument(None, envvar="OTLP_ENDPOINT"),
):
if sharded:
assert (
@@ -49,6 +50,7 @@ def serve(
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)

View File

@@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
@@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@dataclass
class Seq2SeqLMBatch(Batch):
@@ -107,6 +110,7 @@ class Seq2SeqLMBatch(Batch):
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""
@@ -324,6 +328,7 @@ class Seq2SeqLM(Model):
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
@tracer.start_as_current_span("forward")
def forward(
self,
input_ids,
@@ -361,6 +366,7 @@ class Seq2SeqLM(Model):
outputs.past_key_values,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
@@ -401,6 +407,7 @@ class Seq2SeqLM(Model):
batch.decoder_input_ids,
)
with tracer.start_as_current_span("post_processing"):
# For each member of the batch
for i, (
request,
@@ -436,7 +443,9 @@ class Seq2SeqLM(Model):
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
-output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+output_text = self.decode(
+decoder_input_ids[-new_decoder_input_length:]
+)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):

View File

@@ -13,7 +13,7 @@ from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import OpenTelemetryAioServerInterceptorUnix
+from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -101,7 +101,12 @@ def serve(
logger.exception("Error when initializing model")
raise
-server = aio.server(interceptors=[ExceptionInterceptor(), OpenTelemetryAioServerInterceptorUnix()])
+server = aio.server(
+interceptors=[
+ExceptionInterceptor(),
+UDSOpenTelemetryAioServerInterceptor(),
+]
+)
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls), server
)

View File

@@ -2,7 +2,9 @@ import grpc
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
-from opentelemetry.instrumentation.grpc._aio_server import OpenTelemetryAioServerInterceptor
+from opentelemetry.instrumentation.grpc._aio_server import (
+OpenTelemetryAioServerInterceptor,
+)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
@@ -14,15 +16,13 @@ from opentelemetry.sdk.trace.export import (
from typing import Optional
-class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
+class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
def __init__(self):
super().__init__(trace.get_tracer(__name__))
-def _start_span(
-self, handler_call_details, context, set_status_on_exception=False
-):
+def _start_span(self, handler_call_details, context, set_status_on_exception=False):
"""
-Rewrite _start_span method to support Unix socket gRPC context
+Rewrite _start_span method to support Unix Domain Socket gRPC contexts
"""
# standard attributes
@@ -33,9 +33,7 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
# if we have details about the call, split into service and method
if handler_call_details.method:
-service, method = handler_call_details.method.lstrip("/").split(
-"/", 1
-)
+service, method = handler_call_details.method.lstrip("/").split("/", 1)
attributes.update(
{
SpanAttributes.RPC_METHOD: method,
@@ -59,17 +57,12 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
)
-def setup_tracing(shard: int, otlp_endpoint: Optional[str]):
+def setup_tracing(shard: int, otlp_endpoint: str):
-resource = Resource.create(attributes={"service.name": f"text-generation-server.{shard}"})
+resource = Resource.create(
+attributes={"service.name": f"text-generation-inference.server-{shard}"}
+)
-trace.set_tracer_provider(TracerProvider(resource=resource))
-if otlp_endpoint is None:
-# span_exporter = ConsoleSpanExporter(out=open(os.devnull, "w"))
-span_exporter = ConsoleSpanExporter()
-span_processor = SimpleSpanProcessor(span_exporter)
-else:
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)
trace.set_tracer_provider(TracerProvider(resource=resource))
trace.get_tracer_provider().add_span_processor(span_processor)