Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
add tracing to rust router
This commit is contained in:
parent
04015dfa90
commit
b3cc379550
Cargo.lock (generated): 273 lines changed
@ -424,6 +424,19 @@ dependencies = [
"syn",
]

[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core",
]

[[package]]
name = "derive_builder"
version = "0.9.0"
@ -717,6 +730,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"

[[package]]
name = "grpc-metadata"
version = "0.1.0"
dependencies = [
"opentelemetry",
"tonic 0.6.2",
"tracing",
"tracing-opentelemetry",
]

[[package]]
name = "h2"
version = "0.3.15"
@ -1010,6 +1033,15 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"

[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata",
]

[[package]]
name = "matchit"
version = "0.7.0"
@ -1231,6 +1263,86 @@ dependencies = [
"vcpkg",
]

[[package]]
name = "opentelemetry"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
]

[[package]]
name = "opentelemetry-otlp"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde"
dependencies = [
"async-trait",
"futures",
"futures-util",
"http",
"opentelemetry",
"opentelemetry-proto",
"prost 0.11.6",
"thiserror",
"tokio",
"tonic 0.8.3",
]

[[package]]
name = "opentelemetry-proto"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28"
dependencies = [
"futures",
"futures-util",
"opentelemetry",
"prost 0.11.6",
"tonic 0.8.3",
"tonic-build 0.8.4",
]

[[package]]
name = "opentelemetry_api"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22"
dependencies = [
"fnv",
"futures-channel",
"futures-util",
"indexmap",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
]

[[package]]
name = "opentelemetry_sdk"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113"
dependencies = [
"async-trait",
"crossbeam-channel",
"dashmap",
"fnv",
"futures-channel",
"futures-executor",
"futures-util",
"once_cell",
"opentelemetry_api",
"percent-encoding",
"rand",
"thiserror",
"tokio",
"tokio-stream",
]

[[package]]
name = "os_str_bytes"
version = "6.3.1"
@ -1332,6 +1444,16 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"

[[package]]
name = "prettyplease"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e97e3215779627f01ee256d2fad52f3d95e8e1c11e9fc6fd08f7cd455d5d5c78"
dependencies = [
"proc-macro2",
"syn",
]

[[package]]
name = "proc-macro-error"
version = "1.0.4"
@ -1372,7 +1494,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001"
dependencies = [
"bytes",
"prost-derive",
"prost-derive 0.9.0",
]

[[package]]
name = "prost"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698"
dependencies = [
"bytes",
"prost-derive 0.11.6",
]

[[package]]
@ -1388,13 +1520,35 @@ dependencies = [
"log",
"multimap",
"petgraph",
"prost",
"prost-types",
"prost 0.9.0",
"prost-types 0.9.0",
"regex",
"tempfile",
"which",
]

[[package]]
name = "prost-build"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e"
dependencies = [
"bytes",
"heck 0.4.0",
"itertools 0.10.5",
"lazy_static",
"log",
"multimap",
"petgraph",
"prettyplease",
"prost 0.11.6",
"prost-types 0.11.6",
"regex",
"syn",
"tempfile",
"which",
]

[[package]]
name = "prost-derive"
version = "0.9.0"
@ -1408,6 +1562,19 @@ dependencies = [
"syn",
]

[[package]]
name = "prost-derive"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d"
dependencies = [
"anyhow",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn",
]

[[package]]
name = "prost-types"
version = "0.9.0"
@ -1415,7 +1582,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a"
dependencies = [
"bytes",
"prost",
"prost 0.9.0",
]

[[package]]
name = "prost-types"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788"
dependencies = [
"bytes",
"prost 0.11.6",
]

[[package]]
@ -1523,6 +1700,15 @@ dependencies = [
"regex-syntax",
]

[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax",
]

[[package]]
name = "regex-syntax"
version = "0.6.28"
@ -1891,11 +2077,12 @@ name = "text-generation-client"
version = "0.2.1"
dependencies = [
"futures",
"prost",
"grpc-metadata",
"prost 0.9.0",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tonic 0.6.2",
"tonic-build 0.6.2",
"tower",
"tracing",
"tracing-error",
@ -1925,6 +2112,8 @@ dependencies = [
"clap 4.0.22",
"futures",
"nohash-hasher",
"opentelemetry",
"opentelemetry-otlp",
"parking_lot",
"rand",
"serde",
@ -1935,6 +2124,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
@ -2148,8 +2338,8 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost",
"prost-derive",
"prost 0.9.0",
"prost-derive 0.9.0",
"tokio",
"tokio-stream",
"tokio-util 0.6.10",
@ -2160,6 +2350,38 @@ dependencies = [
"tracing-futures",
]

[[package]]
name = "tonic"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64",
"bytes",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost 0.11.6",
"prost-derive 0.11.6",
"tokio",
"tokio-stream",
"tokio-util 0.7.4",
"tower",
"tower-layer",
"tower-service",
"tracing",
"tracing-futures",
]

[[package]]
name = "tonic-build"
version = "0.6.2"
@ -2167,7 +2389,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9403f1bafde247186684b230dc6f38b5cd514584e8bec1dd32514be4745fa757"
dependencies = [
"proc-macro2",
"prost-build",
"prost-build 0.9.0",
"quote",
"syn",
]

[[package]]
name = "tonic-build"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4"
dependencies = [
"prettyplease",
"proc-macro2",
"prost-build 0.11.6",
"quote",
"syn",
]
@ -2288,6 +2523,20 @@ dependencies = [
"tracing-core",
]

[[package]]
name = "tracing-opentelemetry"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de"
dependencies = [
"once_cell",
"opentelemetry",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
]

[[package]]
name = "tracing-serde"
version = "0.1.3"
@ -2304,12 +2553,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
"tracing-serde",
@ -2,6 +2,7 @@
members = [
"router",
"router/client",
"router/grpc-metadata",
"launcher"
]

@ -19,6 +19,8 @@ text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145"
@ -28,7 +30,8 @@ tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] }
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

@ -5,6 +5,7 @@ edition = "2021"

[dependencies]
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.9"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }

@ -2,8 +2,9 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::*;
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Clone)]
@ -38,12 +39,8 @@ impl Client {
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {});
let response = self
.stub
.service_discovery(request)
.instrument(info_span!("service_discovery"))
.await?;
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
let urls = response
.into_inner()
.urls
@ -60,11 +57,8 @@ impl Client {
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {});
self.stub
.clear_cache(request)
.instrument(info_span!("clear_cache"))
.await?;
let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}

@ -74,13 +68,8 @@ impl Client {
/// and the next cached batch
#[instrument(skip(self))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
let response = self
.stub
.prefill(request)
.instrument(info_span!("prefill"))
.await?
.into_inner();
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
}

@ -93,13 +82,8 @@ impl Client {
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(DecodeRequest { batches });
let response = self
.stub
.decode(request)
.instrument(info_span!("decode"))
.await?
.into_inner();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
}
}

@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation};
use futures::future::join_all;
use futures::future::select_all;
use tonic::transport::Uri;
use tracing::instrument;

/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
@ -38,6 +39,7 @@ impl ShardedClient {
}

/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let futures: Vec<_> = self
.clients
@ -51,6 +53,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip(self))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
@ -66,6 +69,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip(self))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,

router/grpc-metadata/Cargo.toml (new file, 10 lines)
@ -0,0 +1,10 @@
[package]
name = "grpc-metadata"
version = "0.1.0"
edition = "2021"

[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.6"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"

router/grpc-metadata/src/lib.rs (new file, 62 lines)
@ -0,0 +1,62 @@
//! A crate to extract and inject an OpenTelemetry context from and to a gRPC request.
//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples

use opentelemetry::global;
use opentelemetry::propagation::{Extractor, Injector};
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Extract context metadata from a gRPC request's metadata
struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap);

impl<'a> Extractor for MetadataExtractor<'a> {
/// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|metadata| metadata.to_str().ok())
}

/// Collect all the keys from the MetadataMap.
fn keys(&self) -> Vec<&str> {
self.0
.keys()
.map(|key| match key {
tonic::metadata::KeyRef::Ascii(v) => v.as_str(),
tonic::metadata::KeyRef::Binary(v) => v.as_str(),
})
.collect::<Vec<_>>()
}
}

/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);

impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) {
self.0.insert(key, val);
}
}
}
}

/// Get a context from the global context and inject the span into a gRPC request's metadata.
fn inject(metadata: &mut tonic::metadata::MetadataMap) {
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::Span::current().context(),
&mut MetadataInjector(metadata),
)
})
}

pub trait InjectTelemetryContext {
fn inject_context(self) -> Self;
}

impl<T> InjectTelemetryContext for tonic::Request<T> {
fn inject_context(mut self) -> Self {
inject(self.metadata_mut());
self
}
}
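
Usage note (not part of the diff): the router's gRPC client attaches the current trace to every outgoing request by calling inject_context() on the tonic request, as the client.rs changes below show. A minimal, hypothetical sketch of that pattern, with MyRequest standing in for a generated protobuf type such as ClearCacheRequest:

// Hypothetical sketch: attach the active tracing span's OpenTelemetry context
// to an outgoing tonic request. `MyRequest` is a placeholder for a generated
// protobuf message; the real code uses types like ClearCacheRequest.
use grpc_metadata::InjectTelemetryContext;

#[derive(Default)]
struct MyRequest;

fn build_traced_request() -> tonic::Request<MyRequest> {
    // inject_context() copies the current span context into the request
    // metadata via the globally registered text-map propagator, so the
    // Python shard can continue the same trace on its side.
    tonic::Request::new(MyRequest::default()).inject_context()
}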

@ -3,6 +3,7 @@ use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use nohash_hasher::IntMap;
use opentelemetry::trace::TraceContextExt;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
@ -13,7 +14,8 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
use tracing::{info_span, instrument, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Inference struct
#[derive(Clone)]
@ -69,6 +71,7 @@ impl Infer {
}

/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
@ -87,6 +90,8 @@ impl Infer {
self.queue.append(Entry {
request: valid_request,
response_tx,
parent_span: info_span!("entry"),
batch_span: None,
time: Instant::now(),
batch_time: None,
_permit: permit,
@ -101,6 +106,7 @@ impl Infer {
}

/// Add a new request to the queue and return a InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
&self,
request: GenerateRequest,
@ -169,7 +175,6 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, queue, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
@ -188,8 +193,10 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;

// We loop until we do not receive any cached batch from the inference server (== until
@ -210,13 +217,15 @@ async fn batching_task(
};

// Try to get a new batch
if let Some((mut new_entries, new_batch)) = queue
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries).await;
wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
@ -226,8 +235,20 @@ async fn batching_task(
}
}
}
let next_batch_span = info_span!("batch");
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span for this entry/batch tuple
let entry_batch_span = info_span!(parent: &entry.parent_span, "infer");
// Add link to span
entry_batch_span
.add_link(next_batch_span.context().span().span_context().clone());
// Update entry
entry.batch_span = Some(entry_batch_span);
});

cached_batch = wrap_future(client.decode(batches), &mut entries).await;
cached_batch = wrap_future(client.decode(batches), &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
}
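
Aside (not from the diff): the span wiring above gives every batch its own "batch" span, while each request keeps a per-batch child span parented to the request's own span, and the two are tied together with a span link. A simplified, hypothetical sketch of the parenting part of that topology (field and function names here are illustrative):

use tracing::{info_span, Span};

struct Entry {
    /// Span created when the request entered the queue.
    parent_span: Span,
    /// Span covering this entry's participation in the current batch.
    batch_span: Option<Span>,
}

// Create one span for the batch and one child "infer" span per entry.
// Parenting the "infer" span to the request keeps the request's timeline in
// one trace even when the request is carried across several batches.
fn attach_entries_to_batch(entries: &mut [Entry]) -> Span {
    let batch_span = info_span!("batch");
    for entry in entries.iter_mut() {
        entry.batch_span = Some(info_span!(parent: &entry.parent_span, "infer"));
    }
    batch_span
}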

@ -235,6 +256,7 @@ async fn batching_task(
}

/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip(future))]
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
@ -253,6 +275,7 @@ async fn wrap_future(
}

/// Send errors to Infer for all `entries`
#[instrument]
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// unwrap_or is valid here as we don't care if the receiver is gone.
@ -264,6 +287,7 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
}

/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
@ -272,6 +296,8 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.get(&generation.request_id)
.expect("ID not found in entries. This is a bug.");

let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation").entered();

if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.

@ -1,9 +1,17 @@
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Text Generation Inference webserver entrypoint
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
@ -27,6 +35,8 @@ struct Args {
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
}

fn main() -> Result<(), std::io::Error> {
@ -43,14 +53,9 @@ fn main() -> Result<(), std::io::Error> {
tokenizer_name,
validation_workers,
json_output,
otlp_endpoint,
} = args;

if json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}

if validation_workers == 0 {
panic!("validation_workers must be > 0");
}
@ -67,6 +72,8 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
init_logging(otlp_endpoint, json_output);

// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
@ -96,3 +103,57 @@ fn main() -> Result<(), std::io::Error> {
Ok(())
})
}

/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new();

// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);

let fmt_layer = match json_output {
true => fmt_layer
.json()
.flatten_event(true)
.with_span_list(false)
.boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);

// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());

let tracer =
opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(trace::config().with_resource(Resource::new(vec![
KeyValue::new("service.name", "text-generation-inference.router"),
])))
.install_batch(opentelemetry::runtime::Tokio);

if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
};
}

// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
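
Note (not from the diff): with init_logging in place, log verbosity is read from the LOG_LEVEL environment variable through the EnvFilter shown above, and OTLP export is only enabled when an endpoint is supplied. Since the argument is declared with #[clap(long, env)], it should be settable either as --otlp-endpoint on the command line or through an OTLP_ENDPOINT environment variable (clap derives the variable name from the field). Registering the W3C TraceContextPropagator here is what lets the grpc-metadata crate inject traceparent headers that the Python shards can pick up.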

@ -2,11 +2,14 @@ use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use opentelemetry::trace::TraceContextExt;
use std::cmp::min;
use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Queue entry
#[derive(Debug)]
@ -15,6 +18,10 @@ pub(crate) struct Entry {
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Request Span
pub parent_span: Span,
/// Batch Span
pub batch_span: Option<Span>,
/// Instant when this entry was created
pub time: Instant,
/// Instant when this entry was added to a batch
@ -42,13 +49,17 @@ impl Queue {
}

/// Append an entry to the queue
#[instrument(skip(self))]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
// Unwrap is safe here
self.queue_sender.send(QueueCommand::Append(entry)).unwrap();
self.queue_sender
.send(QueueCommand::Append(entry, Span::current()))
.unwrap();
}

// Get the next batch
#[instrument(skip(self))]
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
@ -63,6 +74,7 @@ impl Queue {
min_size,
max_size,
response_sender,
span: Span::current(),
})
.unwrap();
// Await on response channel
@ -77,15 +89,16 @@ async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {

while let Some(cmd) = receiver.recv().await {
match cmd {
QueueCommand::Append(entry) => state.append(entry),
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch {
min_size,
max_size,
response_sender,
} => {
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size);
response_sender.send(next_batch).unwrap_or(());
}
}),
}
}
}
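
Aside (not from the diff): both the queue and the validation worker use the same pattern of capturing Span::current() when a command is sent over a channel and re-entering it on the receiving side, so the background work stays attributed to the request that triggered it. A stripped-down, hypothetical sketch of that pattern:

use tracing::Span;

// A command carries the caller's span alongside the payload.
enum Command {
    Append(String, Span),
}

fn handle(cmd: Command) {
    match cmd {
        Command::Append(item, span) => span.in_scope(|| {
            // Everything recorded here lands inside the caller's span.
            tracing::info!(%item, "appended to queue");
        }),
    }
}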

@ -131,6 +144,7 @@ impl State {
}
}

let next_batch_span = info_span!("batch");
let next_batch_size = min(self.entries.len(), max_size);

let mut batch_requests = Vec::with_capacity(next_batch_size);
@ -141,6 +155,13 @@ impl State {
self.entries
.drain(..next_batch_size)
.for_each(|(id, mut entry)| {
// Create a new span for this entry/batch tuple
let entry_batch_span = info_span!(parent: &entry.parent_span, "infer");
// Add link to span
entry_batch_span.add_link(next_batch_span.context().span().span_context().clone());
// Update entry
entry.batch_span = Some(entry_batch_span);

batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
@ -162,19 +183,20 @@ impl State {
// Increment batch id
self.next_batch_id += 1;

Some((batch_entries, batch))
Some((batch_entries, batch, next_batch_span))
}
}

type NextBatch = (IntMap<u64, Entry>, Batch);
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
Append(Entry),
Append(Entry, Span),
NextBatch {
min_size: Option<usize>,
max_size: usize,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
},
}

@ -184,6 +206,7 @@ mod tests {
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::{mpsc, Semaphore};
use tracing::info_span;

fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
@ -208,6 +231,8 @@ mod tests {
},
},
response_tx,
parent_span: info_span!("entry"),
batch_span: None,
time: Instant::now(),
batch_time: None,
_permit: permit,

@ -18,7 +18,7 @@ use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument;
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

@ -197,7 +197,7 @@ async fn generate_stream(
let mut error = false;
let details = req.0.parameters.details;

match infer.generate_stream(req.0).await {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {

@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};

const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
@ -36,6 +37,7 @@ impl Validation {
}

/// Validate a payload and get the number of tokens in the input
#[instrument(skip(self))]
pub(crate) async fn validate(
&self,
request: GenerateRequest,
@ -44,7 +46,10 @@ impl Validation {
let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
self.sender.send((request, sender)).await.unwrap();
self.sender
.send((request, sender, Span::current()))
.await
.unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap()
@ -97,10 +102,12 @@ fn validation_worker(
let mut rng = rand::thread_rng();

// Loop over requests
while let Some((request, response_tx)) = receiver.blocking_recv() {
response_tx
.send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(())
while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
response_tx
.send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(())
})
}
}

@ -203,6 +210,7 @@ fn validate(
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
Span,
);

#[derive(Debug)]

@ -4,6 +4,7 @@ import typer

from pathlib import Path
from loguru import logger
from typer import Argument
from typing import Optional

from text_generation import server, utils
@ -19,9 +20,9 @@ def serve(
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
otlp_endpoint: Optional[str] = None,
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = Argument(None, envvar="OTLP_ENDPOINT"),
):
if sharded:
assert (
@ -49,7 +50,8 @@ def serve(
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

server.serve(model_id, revision, sharded, quantize, uds_path)

@ -1,6 +1,7 @@
import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class Seq2SeqLMBatch(Batch):
@ -107,6 +110,7 @@ class Seq2SeqLMBatch(Batch):
)

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""

@ -324,6 +328,7 @@ class Seq2SeqLM(Model):
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

@tracer.start_as_current_span("forward")
def forward(
self,
input_ids,
@ -361,6 +366,7 @@ class Seq2SeqLM(Model):
outputs.past_key_values,
)

@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
@ -401,91 +407,94 @@ class Seq2SeqLM(Model):
batch.decoder_input_ids,
)

# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)

# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1

# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)

# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)

if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None

generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
with tracer.start_as_current_span("post_processing"):
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)

# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1

# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts

# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)

if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(
decoder_input_ids[-new_decoder_input_length:]
)

# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None

generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)

# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
)
else:
prefill_tokens = None

generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
else:
prefill_tokens = None

generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)

generations.append(generation)
generations.append(generation)

# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:

@ -13,7 +13,7 @@ from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import OpenTelemetryAioServerInterceptorUnix
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@ -101,7 +101,12 @@ def serve(
logger.exception("Error when initializing model")
raise

server = aio.server(interceptors=[ExceptionInterceptor(), OpenTelemetryAioServerInterceptorUnix()])
server = aio.server(
interceptors=[
ExceptionInterceptor(),
UDSOpenTelemetryAioServerInterceptor(),
]
)
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls), server
)

@ -2,7 +2,9 @@ import grpc

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.grpc._aio_server import OpenTelemetryAioServerInterceptor
from opentelemetry.instrumentation.grpc._aio_server import (
OpenTelemetryAioServerInterceptor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
@ -14,15 +16,13 @@ from opentelemetry.sdk.trace.export import (
from typing import Optional


class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
def __init__(self):
super().__init__(trace.get_tracer(__name__))

def _start_span(
self, handler_call_details, context, set_status_on_exception=False
):
def _start_span(self, handler_call_details, context, set_status_on_exception=False):
"""
Rewrite _start_span method to support Unix socket gRPC context
Rewrite _start_span method to support Unix Domain Socket gRPC contexts
"""

# standard attributes
@ -33,9 +33,7 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):

# if we have details about the call, split into service and method
if handler_call_details.method:
service, method = handler_call_details.method.lstrip("/").split(
"/", 1
)
service, method = handler_call_details.method.lstrip("/").split("/", 1)
attributes.update(
{
SpanAttributes.RPC_METHOD: method,
@ -59,17 +57,12 @@ class OpenTelemetryAioServerInterceptorUnix(OpenTelemetryAioServerInterceptor):
)


def setup_tracing(shard: int, otlp_endpoint: Optional[str]):
resource = Resource.create(attributes={"service.name": f"text-generation-server.{shard}"})
def setup_tracing(shard: int, otlp_endpoint: str):
resource = Resource.create(
attributes={"service.name": f"text-generation-inference.server-{shard}"}
)
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)

trace.set_tracer_provider(TracerProvider(resource=resource))

if otlp_endpoint is None:
# span_exporter = ConsoleSpanExporter(out=open(os.devnull, "w"))
span_exporter = ConsoleSpanExporter()
span_processor = SimpleSpanProcessor(span_exporter)
else:
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)

trace.get_tracer_provider().add_span_processor(span_processor)
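
Closing note (not from the diff): as the final hunk shows, setup_tracing now names each shard's service "text-generation-inference.server-{shard}" and, when no OTLP endpoint is configured, falls back to a ConsoleSpanExporter behind a SimpleSpanProcessor so spans remain visible locally; with an endpoint it keeps the batched OTLP export. The Python CLI only calls setup_tracing when an endpoint is actually provided, so in practice the console fallback is reached only if setup_tracing is invoked directly.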