add tracing to rust router

OlivierDehaene 2023-02-09 19:24:09 +01:00
parent 04015dfa90
commit b3cc379550
17 changed files with 615 additions and 168 deletions

Cargo.lock (generated)

@ -424,6 +424,19 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "derive_builder"
version = "0.9.0"
@ -717,6 +730,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "grpc-metadata"
version = "0.1.0"
dependencies = [
"opentelemetry",
"tonic 0.6.2",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "h2"
version = "0.3.15"
@ -1010,6 +1033,15 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.7.0"
@ -1231,6 +1263,86 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "opentelemetry"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde"
dependencies = [
"async-trait",
"futures",
"futures-util",
"http",
"opentelemetry",
"opentelemetry-proto",
"prost 0.11.6",
"thiserror",
"tokio",
"tonic 0.8.3",
]
[[package]]
name = "opentelemetry-proto"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28"
dependencies = [
"futures",
"futures-util",
"opentelemetry",
"prost 0.11.6",
"tonic 0.8.3",
"tonic-build 0.8.4",
]
[[package]]
name = "opentelemetry_api"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22"
dependencies = [
"fnv",
"futures-channel",
"futures-util",
"indexmap",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
]
[[package]]
name = "opentelemetry_sdk"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113"
dependencies = [
"async-trait",
"crossbeam-channel",
"dashmap",
"fnv",
"futures-channel",
"futures-executor",
"futures-util",
"once_cell",
"opentelemetry_api",
"percent-encoding",
"rand",
"thiserror",
"tokio",
"tokio-stream",
]
[[package]]
name = "os_str_bytes"
version = "6.3.1"
@ -1332,6 +1444,16 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e97e3215779627f01ee256d2fad52f3d95e8e1c11e9fc6fd08f7cd455d5d5c78"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
@ -1372,7 +1494,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001"
dependencies = [
"bytes",
"prost-derive",
"prost-derive 0.9.0",
]
[[package]]
name = "prost"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698"
dependencies = [
"bytes",
"prost-derive 0.11.6",
]
[[package]]
@ -1388,13 +1520,35 @@ dependencies = [
"log",
"multimap",
"petgraph",
"prost",
"prost-types",
"prost 0.9.0",
"prost-types 0.9.0",
"regex",
"tempfile",
"which",
]
[[package]]
name = "prost-build"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e"
dependencies = [
"bytes",
"heck 0.4.0",
"itertools 0.10.5",
"lazy_static",
"log",
"multimap",
"petgraph",
"prettyplease",
"prost 0.11.6",
"prost-types 0.11.6",
"regex",
"syn",
"tempfile",
"which",
]
[[package]]
name = "prost-derive"
version = "0.9.0"
@ -1408,6 +1562,19 @@ dependencies = [
"syn",
]
[[package]]
name = "prost-derive"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d"
dependencies = [
"anyhow",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "prost-types"
version = "0.9.0"
@ -1415,7 +1582,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a"
dependencies = [
"bytes",
"prost",
"prost 0.9.0",
]
[[package]]
name = "prost-types"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788"
dependencies = [
"bytes",
"prost 0.11.6",
]
[[package]]
@ -1523,6 +1700,15 @@ dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.28"
@ -1891,11 +2077,12 @@ name = "text-generation-client"
version = "0.2.1"
dependencies = [
"futures",
"prost",
"grpc-metadata",
"prost 0.9.0",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tonic 0.6.2",
"tonic-build 0.6.2",
"tower",
"tracing",
"tracing-error",
@ -1925,6 +2112,8 @@ dependencies = [
"clap 4.0.22",
"futures",
"nohash-hasher",
"opentelemetry",
"opentelemetry-otlp",
"parking_lot",
"rand",
"serde",
@ -1935,6 +2124,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
@ -2148,8 +2338,8 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost",
"prost-derive",
"prost 0.9.0",
"prost-derive 0.9.0",
"tokio",
"tokio-stream",
"tokio-util 0.6.10",
@ -2160,6 +2350,38 @@ dependencies = [
"tracing-futures",
]
[[package]]
name = "tonic"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64",
"bytes",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost 0.11.6",
"prost-derive 0.11.6",
"tokio",
"tokio-stream",
"tokio-util 0.7.4",
"tower",
"tower-layer",
"tower-service",
"tracing",
"tracing-futures",
]
[[package]]
name = "tonic-build"
version = "0.6.2"
@ -2167,7 +2389,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9403f1bafde247186684b230dc6f38b5cd514584e8bec1dd32514be4745fa757"
dependencies = [
"proc-macro2",
"prost-build",
"prost-build 0.9.0",
"quote",
"syn",
]
[[package]]
name = "tonic-build"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4"
dependencies = [
"prettyplease",
"proc-macro2",
"prost-build 0.11.6",
"quote",
"syn",
]
@ -2288,6 +2523,20 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de"
dependencies = [
"once_cell",
"opentelemetry",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
]
[[package]]
name = "tracing-serde"
version = "0.1.3"
@ -2304,12 +2553,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
"tracing-serde",


@ -2,6 +2,7 @@
members = [
"router",
"router/client",
"router/grpc-metadata",
"launcher"
]


@ -19,6 +19,8 @@ text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145"
@ -28,7 +30,8 @@ tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] }
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }


@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.9"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }


@ -2,8 +2,9 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::*;
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Clone)]
@ -38,12 +39,8 @@ impl Client {
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {});
let response = self
.stub
.service_discovery(request)
.instrument(info_span!("service_discovery"))
.await?;
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
let urls = response
.into_inner()
.urls
@ -60,11 +57,8 @@ impl Client {
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {});
self.stub
.clear_cache(request)
.instrument(info_span!("clear_cache"))
.await?;
let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
@ -74,13 +68,8 @@ impl Client {
/// and the next cached batch
#[instrument(skip(self))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
let response = self
.stub
.prefill(request)
.instrument(info_span!("prefill"))
.await?
.into_inner();
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
}
@ -93,13 +82,8 @@ impl Client {
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(DecodeRequest { batches });
let response = self
.stub
.decode(request)
.instrument(info_span!("decode"))
.await?
.into_inner();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
}
}


@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation};
use futures::future::join_all;
use futures::future::select_all;
use tonic::transport::Uri;
use tracing::instrument;
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
@ -38,6 +39,7 @@ impl ShardedClient {
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let futures: Vec<_> = self
.clients
@ -51,6 +53,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip(self))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
@ -66,6 +69,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip(self))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,


@ -0,0 +1,10 @@
[package]
name = "grpc-metadata"
version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.6"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"


@ -0,0 +1,62 @@
//! A crate to extract and inject an OpenTelemetry context from and to a gRPC request.
//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples
use opentelemetry::global;
use opentelemetry::propagation::{Extractor, Injector};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Extract context metadata from a gRPC request's metadata
struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap);
impl<'a> Extractor for MetadataExtractor<'a> {
/// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|metadata| metadata.to_str().ok())
}
/// Collect all the keys from the MetadataMap.
fn keys(&self) -> Vec<&str> {
self.0
.keys()
.map(|key| match key {
tonic::metadata::KeyRef::Ascii(v) => v.as_str(),
tonic::metadata::KeyRef::Binary(v) => v.as_str(),
})
.collect::<Vec<_>>()
}
}
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) {
self.0.insert(key, val);
}
}
}
}
/// Get a context from the global context and inject the span into a gRPC request's metadata.
fn inject(metadata: &mut tonic::metadata::MetadataMap) {
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::Span::current().context(),
&mut MetadataInjector(metadata),
)
})
}
pub trait InjectTelemetryContext {
fn inject_context(self) -> Self;
}
impl<T> InjectTelemetryContext for tonic::Request<T> {
fn inject_context(mut self) -> Self {
inject(self.metadata_mut());
self
}
}
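Note that inject() goes through the global text-map propagator, which is a no-op until a propagator is registered; the router registers one in main.rs further down. A minimal sketch of that prerequisite, assuming the opentelemetry 0.18 API used by this crate:

    use opentelemetry::global;
    use opentelemetry::sdk::propagation::TraceContextPropagator;

    fn init_propagation() {
        // Without a registered propagator, propagator.inject_context() injects
        // nothing and the Python shards cannot join the router's trace.
        global::set_text_map_propagator(TraceContextPropagator::new());
    }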


@ -3,6 +3,7 @@ use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use nohash_hasher::IntMap;
use opentelemetry::trace::TraceContextExt;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
@ -13,7 +14,8 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
use tracing::{info_span, instrument, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inference struct
#[derive(Clone)]
@ -69,6 +71,7 @@ impl Infer {
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
@ -87,6 +90,8 @@ impl Infer {
self.queue.append(Entry {
request: valid_request,
response_tx,
parent_span: info_span!("entry"),
batch_span: None,
time: Instant::now(),
batch_time: None,
_permit: permit,
@ -101,6 +106,7 @@ impl Infer {
}
/// Add a new request to the queue and return a InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
&self,
request: GenerateRequest,
@ -169,7 +175,6 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, queue, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
@ -188,8 +193,10 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
@ -210,13 +217,15 @@ async fn batching_task(
};
// Try to get a new batch
if let Some((mut new_entries, new_batch)) = queue
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries).await;
wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
@ -226,8 +235,20 @@ async fn batching_task(
}
}
}
let next_batch_span = info_span!("batch");
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span for this entry/batch tuple
let entry_batch_span = info_span!(parent: &entry.parent_span, "infer");
// Add link to span
entry_batch_span
.add_link(next_batch_span.context().span().span_context().clone());
// Update entry
entry.batch_span = Some(entry_batch_span);
});
cached_batch = wrap_future(client.decode(batches), &mut entries).await;
cached_batch = wrap_future(client.decode(batches), &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
}
@ -235,6 +256,7 @@ async fn batching_task(
}
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip(future))]
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
@ -253,6 +275,7 @@ async fn wrap_future(
}
/// Send errors to Infer for all `entries`
#[instrument]
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// unwrap_or is valid here as we don't care if the receiver is gone.
@ -264,6 +287,7 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
@ -272,6 +296,8 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.get(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation").entered();
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.


@ -1,9 +1,17 @@
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Text Generation Inference webserver entrypoint
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
@ -27,6 +35,8 @@ struct Args {
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
}
fn main() -> Result<(), std::io::Error> {
@ -43,14 +53,9 @@ fn main() -> Result<(), std::io::Error> {
tokenizer_name,
validation_workers,
json_output,
otlp_endpoint,
} = args;
if json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
if validation_workers == 0 {
panic!("validation_workers must be > 0");
}
@ -67,6 +72,8 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
init_logging(otlp_endpoint, json_output);
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
@ -96,3 +103,57 @@ fn main() -> Result<(), std::io::Error> {
Ok(())
})
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer
.json()
.flatten_event(true)
.with_span_list(false)
.boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer =
opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(trace::config().with_resource(Resource::new(vec![
KeyValue::new("service.name", "text-generation-inference.router"),
])))
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
};
}
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
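For reference, EnvFilter::try_from_env("LOG_LEVEL") accepts standard tracing filter directives, so per-crate levels can be set as well; a small illustrative sketch (the directive string is an example, not part of this commit):

    use tracing_subscriber::EnvFilter;

    // Programmatic equivalent of LOG_LEVEL="info,text_generation_router=debug":
    // INFO everywhere, DEBUG for the router crate only.
    fn example_filter() -> EnvFilter {
        EnvFilter::new("info,text_generation_router=debug")
    }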


@ -2,11 +2,14 @@ use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use opentelemetry::trace::TraceContextExt;
use std::cmp::min;
use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Queue entry
#[derive(Debug)]
@ -15,6 +18,10 @@ pub(crate) struct Entry {
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Request Span
pub parent_span: Span,
/// Batch Span
pub batch_span: Option<Span>,
/// Instant when this entry was created
pub time: Instant,
/// Instant when this entry was added to a batch
@ -42,13 +49,17 @@ impl Queue {
}
/// Append an entry to the queue
#[instrument(skip(self))]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
// Unwrap is safe here
self.queue_sender.send(QueueCommand::Append(entry)).unwrap();
self.queue_sender
.send(QueueCommand::Append(entry, Span::current()))
.unwrap();
}
// Get the next batch
#[instrument(skip(self))]
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
@ -63,6 +74,7 @@ impl Queue {
min_size,
max_size,
response_sender,
span: Span::current(),
})
.unwrap();
// Await on response channel
@ -77,15 +89,16 @@ async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
while let Some(cmd) = receiver.recv().await {
match cmd {
QueueCommand::Append(entry) => state.append(entry),
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch {
min_size,
max_size,
response_sender,
} => {
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size);
response_sender.send(next_batch).unwrap_or(());
}
}),
}
}
}
@ -131,6 +144,7 @@ impl State {
}
}
let next_batch_span = info_span!("batch");
let next_batch_size = min(self.entries.len(), max_size);
let mut batch_requests = Vec::with_capacity(next_batch_size);
@ -141,6 +155,13 @@ impl State {
self.entries
.drain(..next_batch_size)
.for_each(|(id, mut entry)| {
// Create a new span for this entry/batch tuple
let entry_batch_span = info_span!(parent: &entry.parent_span, "infer");
// Add link to span
entry_batch_span.add_link(next_batch_span.context().span().span_context().clone());
// Update entry
entry.batch_span = Some(entry_batch_span);
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
@ -162,19 +183,20 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
Some((batch_entries, batch))
Some((batch_entries, batch, next_batch_span))
}
}
type NextBatch = (IntMap<u64, Entry>, Batch);
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
enum QueueCommand {
Append(Entry),
Append(Entry, Span),
NextBatch {
min_size: Option<usize>,
max_size: usize,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
},
}
@ -184,6 +206,7 @@ mod tests {
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::{mpsc, Semaphore};
use tracing::info_span;
fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
@ -208,6 +231,8 @@ mod tests {
},
},
response_tx,
parent_span: info_span!("entry"),
batch_span: None,
time: Instant::now(),
batch_time: None,
_permit: permit,
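The queue and validation changes share one pattern: the caller captures tracing::Span::current() into the channel message, and the background task re-enters it with in_scope so its work is attributed to the originating request. A self-contained sketch of the pattern (names are illustrative, not taken from the diff):

    use tokio::sync::mpsc;
    use tracing::Span;

    enum Command {
        // The caller's span travels with the command across the channel.
        DoWork(u64, Span),
    }

    async fn worker(mut rx: mpsc::UnboundedReceiver<Command>) {
        while let Some(Command::DoWork(id, span)) = rx.recv().await {
            // Re-enter the caller's span so events emitted here are parented
            // to the request that queued the work.
            span.in_scope(|| tracing::info!(id, "processing"));
        }
    }

    fn enqueue(tx: &mpsc::UnboundedSender<Command>, id: u64) {
        tx.send(Command::DoWork(id, Span::current())).unwrap();
    }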


@ -18,7 +18,7 @@ use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument;
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
@ -197,7 +197,7 @@ async fn generate_stream(
let mut error = false;
let details = req.0.parameters.details;
match infer.generate_stream(req.0).await {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {


@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};
const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
@ -36,6 +37,7 @@ impl Validation {
}
/// Validate a payload and get the number of tokens in the input
#[instrument(skip(self))]
pub(crate) async fn validate(
&self,
request: GenerateRequest,
@ -44,7 +46,10 @@ impl Validation {
let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
self.sender.send((request, sender)).await.unwrap();
self.sender
.send((request, sender, Span::current()))
.await
.unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap()
@ -97,10 +102,12 @@ fn validation_worker(
let mut rng = rand::thread_rng();
// Loop over requests
while let Some((request, response_tx)) = receiver.blocking_recv() {
response_tx
.send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(())
while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
response_tx
.send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(())
})
}
}
@ -203,6 +210,7 @@ fn validate(
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
Span,
);
#[derive(Debug)]


@ -4,6 +4,7 @@ import typer
from pathlib import Path
from loguru import logger
from typer import Argument
from typing import Optional
from text_generation import server, utils
@ -19,9 +20,9 @@ def serve(
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
otlp_endpoint: Optional[str] = None,
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = Argument(None, envvar="OTLP_ENDPOINT"),
):
if sharded:
assert (
@ -49,7 +50,8 @@ def serve(
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)


@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@dataclass
class Seq2SeqLMBatch(Batch):
@ -107,6 +110,7 @@ class Seq2SeqLMBatch(Batch):
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""
@ -324,6 +328,7 @@ class Seq2SeqLM(Model):
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
@tracer.start_as_current_span("forward")
def forward(
self,
input_ids,
@ -361,6 +366,7 @@ class Seq2SeqLM(Model):
outputs.past_key_values,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
@ -401,91 +407,94 @@ class Seq2SeqLM(Model):
batch.decoder_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
with tracer.start_as_current_span("post_processing"):
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(
decoder_input_ids[-new_decoder_input_length:]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generations.append(generation)
generations.append(generation)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:


@ -13,7 +13,7 @@ from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import OpenTelemetryAioServerInterceptorUnix
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@ -101,7 +101,12 @@ def serve(
logger.exception("Error when initializing model")
raise
server = aio.server(interceptors=[ExceptionInterceptor(), OpenTelemetryAioServerInterceptorUnix()])
server = aio.server(
interceptors=[
ExceptionInterceptor(),
UDSOpenTelemetryAioServerInterceptor(),
]
)
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls), server
)


@ -2,7 +2,9 @@ import grpc
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.grpc._aio_server import OpenTelemetryAioServerInterceptor
from opentelemetry.instrumentation.grpc._aio_server import (
OpenTelemetryAioServerInterceptor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
@ -14,15 +16,13 @@ from opentelemetry.sdk.trace.export import (
from typing import Optional
class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
def __init__(self):
super().__init__(trace.get_tracer(__name__))
def _start_span(
self, handler_call_details, context, set_status_on_exception=False
):
def _start_span(self, handler_call_details, context, set_status_on_exception=False):
"""
Rewrite _start_span method to support Unix socket gRPC context
Rewrite _start_span method to support Unix Domain Socket gRPC contexts
"""
# standard attributes
@ -33,9 +33,7 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
# if we have details about the call, split into service and method
if handler_call_details.method:
service, method = handler_call_details.method.lstrip("/").split(
"/", 1
)
service, method = handler_call_details.method.lstrip("/").split("/", 1)
attributes.update(
{
SpanAttributes.RPC_METHOD: method,
@ -59,17 +57,12 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
)
def setup_tracing(shard: int, otlp_endpoint: Optional[str]):
resource = Resource.create(attributes={"service.name": f"text-generation-server.{shard}"})
def setup_tracing(shard: int, otlp_endpoint: str):
resource = Resource.create(
attributes={"service.name": f"text-generation-inference.server-{shard}"}
)
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)
trace.set_tracer_provider(TracerProvider(resource=resource))
if otlp_endpoint is None:
# span_exporter = ConsoleSpanExporter(out=open(os.devnull, "w"))
span_exporter = ConsoleSpanExporter()
span_processor = SimpleSpanProcessor(span_exporter)
else:
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)