Update dependencies

This commit is contained in:
OlivierDehaene 2023-02-10 12:14:29 +01:00
parent 7fa81a05b0
commit 1e5a30990b
15 changed files with 440 additions and 456 deletions

View File

@ -23,6 +23,8 @@ jobs:
toolchain: 1.65.0
override: true
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Loading cache.
uses: actions/cache@v2
id: model_cache

769
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,6 @@
FROM rust:1.65 as router-builder
FROM rust:1.67 as router-builder
RUN apt-get update && apt-get install -y protobuf-compiler && rm -rf /var/lib/apt/lists/*
WORKDIR /usr/src
@ -10,7 +12,7 @@ WORKDIR /usr/src/router
RUN cargo install --path .
FROM rust:1.65 as launcher-builder
FROM rust:1.67 as launcher-builder
WORKDIR /usr/src

View File

@ -142,6 +142,20 @@ conda create -n text-generation-inference python=3.9
conda activate text-generation-inference
```
You may also need to install Protoc (the Protocol Buffers compiler, `protoc`).
On Linux (Debian/Ubuntu, or any distribution using `apt`):
```shell
apt install -y protobuf-compiler
```
On macOS, using Homebrew:
```shell
brew install protobuf
```
Then run:
```shell

View File

@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"]
description = "Text Generation Launcher"
[dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] }
ctrlc = { version = "3.2.3", features = ["termination"] }
serde_json = "1.0.89"
clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.5", features = ["termination"] }
serde_json = "1.0.93"
subprocess = "0.2.9"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde = { version = "1.0.150", features = ["derive"] }
reqwest = { version = "0.11.14", features = ["blocking", "json"] }
serde = { version = "1.0.152", features = ["derive"] }

View File

@ -165,7 +165,7 @@ fn main() -> ExitCode {
"--port".to_string(),
port.to_string(),
"--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path),
format!("{shard_uds_path}-0"),
"--tokenizer-name".to_string(),
model_id,
];
@ -269,7 +269,7 @@ fn shard_manager(
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
let uds_string = format!("{}-{}", uds_path, rank);
let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();

View File

@ -15,23 +15,24 @@ path = "src/main.rs"
[dependencies]
async-stream = "0.3.3"
axum = { version = "0.6.4", features = ["json"] }
axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26"
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tracing = "0.1.36"
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] }
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@ -6,13 +6,13 @@ edition = "2021"
[dependencies]
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.9"
prost = "^0.11"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }
tonic = "^0.6"
tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.8"
tower = "^0.4"
tracing = "^0.1"
tracing-error = "^0.2"
[build-dependencies]
tonic-build = "0.6.2"
tonic-build = "^0.8"

View File

@ -9,7 +9,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.out_dir("src/pb")
.include_file("mod.rs")
.compile(&["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

View File

@ -5,6 +5,6 @@ edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.6"
tonic = "^0.8"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"

View File

@ -33,7 +33,7 @@ impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) {
if let Ok(val) = value.parse() {
self.0.insert(key, val);
}
}

View File

@ -2,6 +2,7 @@
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
@ -130,21 +131,26 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer =
opentelemetry_otlp::new_pipeline()
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(trace::config().with_resource(Resource::new(vec![
KeyValue::new("service.name", "text-generation-inference.router"),
])))
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"text-generation-inference.router",
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
axum_tracing_opentelemetry::init_propagator().unwrap();
};
}

View File

@ -267,7 +267,7 @@ mod tests {
state.append(default_entry());
state.append(default_entry());
let (entries, batch) = state.next_batch(None, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@ -296,7 +296,7 @@ mod tests {
state.append(default_entry());
state.append(default_entry());
let (entries, batch) = state.next_batch(None, 1).unwrap();
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@ -308,7 +308,7 @@ mod tests {
state.append(default_entry());
let (entries, batch) = state.next_batch(None, 3).unwrap();
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
@ -340,7 +340,7 @@ mod tests {
queue.append(default_entry());
queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 2).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@ -360,7 +360,7 @@ mod tests {
queue.append(default_entry());
queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 1).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@ -368,7 +368,7 @@ mod tests {
queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 3).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));

View File

@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use std::convert::Infallible;
use std::net::SocketAddr;
@ -135,11 +136,11 @@ async fn generate(
);
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text);
@ -355,7 +356,8 @@ pub async fn run(
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
.route("/health", get(health))
.layer(Extension(infer));
.layer(Extension(infer))
.layer(opentelemetry_tracing_layer());
// Run server
axum::Server::bind(&addr)

View File

@ -1,3 +1,3 @@
[toolchain]
channel = "1.65.0"
channel = "1.67.0"
components = ["rustfmt", "clippy"]