mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Update dependencies
This commit is contained in:
parent
7fa81a05b0
commit
1e5a30990b
2
.github/workflows/tests.yaml
vendored
2
.github/workflows/tests.yaml
vendored
@ -23,6 +23,8 @@ jobs:
|
||||
toolchain: 1.65.0
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install Protoc
|
||||
uses: arduino/setup-protoc@v1
|
||||
- name: Loading cache.
|
||||
uses: actions/cache@v2
|
||||
id: model_cache
|
||||
|
769
Cargo.lock
generated
769
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,6 @@
|
||||
FROM rust:1.65 as router-builder
|
||||
FROM rust:1.67 as router-builder
|
||||
|
||||
RUN apt-get update && apt-get install -y protobuf-compiler && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
@ -10,7 +12,7 @@ WORKDIR /usr/src/router
|
||||
|
||||
RUN cargo install --path .
|
||||
|
||||
FROM rust:1.65 as launcher-builder
|
||||
FROM rust:1.67 as launcher-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
14
README.md
14
README.md
@ -142,6 +142,20 @@ conda create -n text-generation-inference python=3.9
|
||||
conda activate text-generation-inference
|
||||
```
|
||||
|
||||
You may also need to install Protoc.
|
||||
|
||||
On Linux:
|
||||
|
||||
```shell
|
||||
apt install -y protobuf-compiler
|
||||
```
|
||||
|
||||
On MacOS, using Homebrew:
|
||||
|
||||
```shell
|
||||
brew install protobuf
|
||||
```
|
||||
|
||||
Then run:
|
||||
|
||||
```shell
|
||||
|
@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Launcher"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
ctrlc = { version = "3.2.3", features = ["termination"] }
|
||||
serde_json = "1.0.89"
|
||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||
ctrlc = { version = "3.2.5", features = ["termination"] }
|
||||
serde_json = "1.0.93"
|
||||
subprocess = "0.2.9"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||
|
||||
[dev-dependencies]
|
||||
float_eq = "1.0.1"
|
||||
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
||||
serde = { version = "1.0.150", features = ["derive"] }
|
||||
reqwest = { version = "0.11.14", features = ["blocking", "json"] }
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
|
@ -165,7 +165,7 @@ fn main() -> ExitCode {
|
||||
"--port".to_string(),
|
||||
port.to_string(),
|
||||
"--master-shard-uds-path".to_string(),
|
||||
format!("{}-0", shard_uds_path),
|
||||
format!("{shard_uds_path}-0"),
|
||||
"--tokenizer-name".to_string(),
|
||||
model_id,
|
||||
];
|
||||
@ -269,7 +269,7 @@ fn shard_manager(
|
||||
_shutdown_sender: mpsc::Sender<()>,
|
||||
) {
|
||||
// Get UDS path
|
||||
let uds_string = format!("{}-{}", uds_path, rank);
|
||||
let uds_string = format!("{uds_path}-{rank}");
|
||||
let uds = Path::new(&uds_string);
|
||||
// Clean previous runs
|
||||
fs::remove_file(uds).unwrap_or_default();
|
||||
|
@ -15,23 +15,24 @@ path = "src/main.rs"
|
||||
[dependencies]
|
||||
async-stream = "0.3.3"
|
||||
axum = { version = "0.6.4", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.9.0"
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
futures = "0.3.24"
|
||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||
futures = "0.3.26"
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
parking_lot = "0.12.1"
|
||||
rand = "0.8.5"
|
||||
serde = "1.0.145"
|
||||
serde_json = "1.0.85"
|
||||
thiserror = "1.0.37"
|
||||
tokenizers = "0.13.0"
|
||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
serde = "1.0.152"
|
||||
serde_json = "1.0.93"
|
||||
thiserror = "1.0.38"
|
||||
tokenizers = "0.13.2"
|
||||
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.11"
|
||||
tracing = "0.1.36"
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.18.0"
|
||||
tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] }
|
||||
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||
|
||||
|
@ -6,13 +6,13 @@ edition = "2021"
|
||||
[dependencies]
|
||||
futures = "^0.3"
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
prost = "^0.9"
|
||||
prost = "^0.11"
|
||||
thiserror = "^1.0"
|
||||
tokio = { version = "^1.21", features = ["sync"] }
|
||||
tonic = "^0.6"
|
||||
tokio = { version = "^1.25", features = ["sync"] }
|
||||
tonic = "^0.8"
|
||||
tower = "^0.4"
|
||||
tracing = "^0.1"
|
||||
tracing-error = "^0.2"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.6.2"
|
||||
tonic-build = "^0.8"
|
||||
|
@ -9,7 +9,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.out_dir("src/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile(&["../../proto/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -5,6 +5,6 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
opentelemetry = "0.18.0"
|
||||
tonic = "^0.6"
|
||||
tonic = "^0.8"
|
||||
tracing = "^0.1"
|
||||
tracing-opentelemetry = "0.18.0"
|
||||
|
@ -33,7 +33,7 @@ impl<'a> Injector for MetadataInjector<'a> {
|
||||
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
|
||||
fn set(&mut self, key: &str, value: String) {
|
||||
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) {
|
||||
if let Ok(val) = value.parse() {
|
||||
self.0.insert(key, val);
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@
|
||||
use clap::Parser;
|
||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry::sdk::trace;
|
||||
use opentelemetry::sdk::trace::Sampler;
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::{global, KeyValue};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
@ -130,21 +131,26 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
|
||||
let tracer =
|
||||
opentelemetry_otlp::new_pipeline()
|
||||
let tracer = opentelemetry_otlp::new_pipeline()
|
||||
.tracing()
|
||||
.with_exporter(
|
||||
opentelemetry_otlp::new_exporter()
|
||||
.tonic()
|
||||
.with_endpoint(otlp_endpoint),
|
||||
)
|
||||
.with_trace_config(trace::config().with_resource(Resource::new(vec![
|
||||
KeyValue::new("service.name", "text-generation-inference.router"),
|
||||
])))
|
||||
.with_trace_config(
|
||||
trace::config()
|
||||
.with_resource(Resource::new(vec![KeyValue::new(
|
||||
"service.name",
|
||||
"text-generation-inference.router",
|
||||
)]))
|
||||
.with_sampler(Sampler::AlwaysOn),
|
||||
)
|
||||
.install_batch(opentelemetry::runtime::Tokio);
|
||||
|
||||
if let Ok(tracer) = tracer {
|
||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||
axum_tracing_opentelemetry::init_propagator().unwrap();
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -267,7 +267,7 @@ mod tests {
|
||||
state.append(default_entry());
|
||||
state.append(default_entry());
|
||||
|
||||
let (entries, batch) = state.next_batch(None, 2).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -296,7 +296,7 @@ mod tests {
|
||||
state.append(default_entry());
|
||||
state.append(default_entry());
|
||||
|
||||
let (entries, batch) = state.next_batch(None, 1).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -308,7 +308,7 @@ mod tests {
|
||||
|
||||
state.append(default_entry());
|
||||
|
||||
let (entries, batch) = state.next_batch(None, 3).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
@ -340,7 +340,7 @@ mod tests {
|
||||
queue.append(default_entry());
|
||||
queue.append(default_entry());
|
||||
|
||||
let (entries, batch) = queue.next_batch(None, 2).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -360,7 +360,7 @@ mod tests {
|
||||
queue.append(default_entry());
|
||||
queue.append(default_entry());
|
||||
|
||||
let (entries, batch) = queue.next_batch(None, 1).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -368,7 +368,7 @@ mod tests {
|
||||
|
||||
queue.append(default_entry());
|
||||
|
||||
let (entries, batch) = queue.next_batch(None, 3).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
|
@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use futures::Stream;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
@ -135,11 +136,11 @@ async fn generate(
|
||||
);
|
||||
|
||||
// Tracing metadata
|
||||
span.record("total_time", format!("{:?}", total_time));
|
||||
span.record("validation_time", format!("{:?}", validation_time));
|
||||
span.record("queue_time", format!("{:?}", queue_time));
|
||||
span.record("inference_time", format!("{:?}", inference_time));
|
||||
span.record("time_per_token", format!("{:?}", time_per_token));
|
||||
span.record("total_time", format!("{total_time:?}"));
|
||||
span.record("validation_time", format!("{validation_time:?}"));
|
||||
span.record("queue_time", format!("{queue_time:?}"));
|
||||
span.record("inference_time", format!("{inference_time:?}"));
|
||||
span.record("time_per_token", format!("{time_per_token:?}"));
|
||||
span.record("seed", format!("{:?}", response.generated_text.seed));
|
||||
tracing::info!("Output: {}", response.generated_text.text);
|
||||
|
||||
@ -355,7 +356,8 @@ pub async fn run(
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/", get(health))
|
||||
.route("/health", get(health))
|
||||
.layer(Extension(infer));
|
||||
.layer(Extension(infer))
|
||||
.layer(opentelemetry_tracing_layer());
|
||||
|
||||
// Run server
|
||||
axum::Server::bind(&addr)
|
||||
|
@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "1.65.0"
|
||||
channel = "1.67.0"
|
||||
components = ["rustfmt", "clippy"]
|
Loading…
Reference in New Issue
Block a user