Update dependencies

This commit is contained in:
OlivierDehaene 2023-02-10 12:14:29 +01:00
parent 7fa81a05b0
commit 1e5a30990b
15 changed files with 440 additions and 456 deletions

View File

@ -23,6 +23,8 @@ jobs:
toolchain: 1.65.0 toolchain: 1.65.0
override: true override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Loading cache. - name: Loading cache.
uses: actions/cache@v2 uses: actions/cache@v2
id: model_cache id: model_cache

769
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,6 @@
FROM rust:1.65 as router-builder FROM rust:1.67 as router-builder
RUN apt-get update && apt-get install -y protobuf-compiler && rm -rf /var/lib/apt/lists/*
WORKDIR /usr/src WORKDIR /usr/src
@ -10,7 +12,7 @@ WORKDIR /usr/src/router
RUN cargo install --path . RUN cargo install --path .
FROM rust:1.65 as launcher-builder FROM rust:1.67 as launcher-builder
WORKDIR /usr/src WORKDIR /usr/src

View File

@ -142,6 +142,20 @@ conda create -n text-generation-inference python=3.9
conda activate text-generation-inference conda activate text-generation-inference
``` ```
You may also need to install Protoc.
On Linux:
```shell
apt install -y protobuf-compiler
```
On macOS, using Homebrew:
```shell
brew install protobuf
```
Then run: Then run:
```shell ```shell

View File

@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"]
description = "Text Generation Launcher" description = "Text Generation Launcher"
[dependencies] [dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.3", features = ["termination"] } ctrlc = { version = "3.2.5", features = ["termination"] }
serde_json = "1.0.89" serde_json = "1.0.93"
subprocess = "0.2.9" subprocess = "0.2.9"
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] } tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies] [dev-dependencies]
float_eq = "1.0.1" float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] } reqwest = { version = "0.11.14", features = ["blocking", "json"] }
serde = { version = "1.0.150", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }

View File

@ -165,7 +165,7 @@ fn main() -> ExitCode {
"--port".to_string(), "--port".to_string(),
port.to_string(), port.to_string(),
"--master-shard-uds-path".to_string(), "--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path), format!("{shard_uds_path}-0"),
"--tokenizer-name".to_string(), "--tokenizer-name".to_string(),
model_id, model_id,
]; ];
@ -269,7 +269,7 @@ fn shard_manager(
_shutdown_sender: mpsc::Sender<()>, _shutdown_sender: mpsc::Sender<()>,
) { ) {
// Get UDS path // Get UDS path
let uds_string = format!("{}-{}", uds_path, rank); let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string); let uds = Path::new(&uds_string);
// Clean previous runs // Clean previous runs
fs::remove_file(uds).unwrap_or_default(); fs::remove_file(uds).unwrap_or_default();

View File

@ -15,23 +15,24 @@ path = "src/main.rs"
[dependencies] [dependencies]
async-stream = "0.3.3" async-stream = "0.3.3"
axum = { version = "0.6.4", features = ["json"] } axum = { version = "0.6.4", features = ["json"] }
axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.24" futures = "0.3.26"
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1" parking_lot = "0.12.1"
rand = "0.8.5" rand = "0.8.5"
serde = "1.0.145" serde = "1.0.152"
serde_json = "1.0.85" serde_json = "1.0.93"
thiserror = "1.0.37" thiserror = "1.0.38"
tokenizers = "0.13.0" tokenizers = "0.13.2"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11" tokio-stream = "0.1.11"
tracing = "0.1.36" tracing = "0.1.37"
tracing-opentelemetry = "0.18.0" tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@ -6,13 +6,13 @@ edition = "2021"
[dependencies] [dependencies]
futures = "^0.3" futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" } grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.9" prost = "^0.11"
thiserror = "^1.0" thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] } tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.6" tonic = "^0.8"
tower = "^0.4" tower = "^0.4"
tracing = "^0.1" tracing = "^0.1"
tracing-error = "^0.2" tracing-error = "^0.2"
[build-dependencies] [build-dependencies]
tonic-build = "0.6.2" tonic-build = "^0.8"

View File

@ -9,7 +9,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.out_dir("src/pb") .out_dir("src/pb")
.include_file("mod.rs") .include_file("mod.rs")
.compile(&["../../proto/generate.proto"], &["../../proto"]) .compile(&["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(()) Ok(())
} }

View File

@ -5,6 +5,6 @@ edition = "2021"
[dependencies] [dependencies]
opentelemetry = "0.18.0" opentelemetry = "0.18.0"
tonic = "^0.6" tonic = "^0.8"
tracing = "^0.1" tracing = "^0.1"
tracing-opentelemetry = "0.18.0" tracing-opentelemetry = "0.18.0"

View File

@ -33,7 +33,7 @@ impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) { fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) { if let Ok(val) = value.parse() {
self.0.insert(key, val); self.0.insert(key, val);
} }
} }

View File

@ -2,6 +2,7 @@
use clap::Parser; use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace; use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource; use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue}; use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
@ -130,21 +131,26 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
if let Some(otlp_endpoint) = otlp_endpoint { if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new()); global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = let tracer = opentelemetry_otlp::new_pipeline()
opentelemetry_otlp::new_pipeline() .tracing()
.tracing() .with_exporter(
.with_exporter( opentelemetry_otlp::new_exporter()
opentelemetry_otlp::new_exporter() .tonic()
.tonic() .with_endpoint(otlp_endpoint),
.with_endpoint(otlp_endpoint), )
) .with_trace_config(
.with_trace_config(trace::config().with_resource(Resource::new(vec![ trace::config()
KeyValue::new("service.name", "text-generation-inference.router"), .with_resource(Resource::new(vec![KeyValue::new(
]))) "service.name",
.install_batch(opentelemetry::runtime::Tokio); "text-generation-inference.router",
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer { if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
axum_tracing_opentelemetry::init_propagator().unwrap();
}; };
} }

View File

@ -267,7 +267,7 @@ mod tests {
state.append(default_entry()); state.append(default_entry());
state.append(default_entry()); state.append(default_entry());
let (entries, batch) = state.next_batch(None, 2).unwrap(); let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -296,7 +296,7 @@ mod tests {
state.append(default_entry()); state.append(default_entry());
state.append(default_entry()); state.append(default_entry());
let (entries, batch) = state.next_batch(None, 1).unwrap(); let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -308,7 +308,7 @@ mod tests {
state.append(default_entry()); state.append(default_entry());
let (entries, batch) = state.next_batch(None, 3).unwrap(); let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
@ -340,7 +340,7 @@ mod tests {
queue.append(default_entry()); queue.append(default_entry());
queue.append(default_entry()); queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 2).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -360,7 +360,7 @@ mod tests {
queue.append(default_entry()); queue.append(default_entry());
queue.append(default_entry()); queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 1).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -368,7 +368,7 @@ mod tests {
queue.append(default_entry()); queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 3).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));

View File

@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream; use futures::Stream;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -135,11 +136,11 @@ async fn generate(
); );
// Tracing metadata // Tracing metadata
span.record("total_time", format!("{:?}", total_time)); span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{:?}", validation_time)); span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{:?}", queue_time)); span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{:?}", inference_time)); span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{:?}", time_per_token)); span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", response.generated_text.seed)); span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text); tracing::info!("Output: {}", response.generated_text.text);
@ -355,7 +356,8 @@ pub async fn run(
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/", get(health)) .route("/", get(health))
.route("/health", get(health)) .route("/health", get(health))
.layer(Extension(infer)); .layer(Extension(infer))
.layer(opentelemetry_tracing_layer());
// Run server // Run server
axum::Server::bind(&addr) axum::Server::bind(&addr)

View File

@ -1,3 +1,3 @@
[toolchain] [toolchain]
channel = "1.65.0" channel = "1.67.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]