mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Update dependencies
This commit is contained in:
parent
7fa81a05b0
commit
1e5a30990b
2
.github/workflows/tests.yaml
vendored
2
.github/workflows/tests.yaml
vendored
@ -23,6 +23,8 @@ jobs:
|
|||||||
toolchain: 1.65.0
|
toolchain: 1.65.0
|
||||||
override: true
|
override: true
|
||||||
components: rustfmt, clippy
|
components: rustfmt, clippy
|
||||||
|
- name: Install Protoc
|
||||||
|
uses: arduino/setup-protoc@v1
|
||||||
- name: Loading cache.
|
- name: Loading cache.
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v2
|
||||||
id: model_cache
|
id: model_cache
|
||||||
|
769
Cargo.lock
generated
769
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,6 @@
|
|||||||
FROM rust:1.65 as router-builder
|
FROM rust:1.67 as router-builder
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y protobuf-compiler && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
@ -10,7 +12,7 @@ WORKDIR /usr/src/router
|
|||||||
|
|
||||||
RUN cargo install --path .
|
RUN cargo install --path .
|
||||||
|
|
||||||
FROM rust:1.65 as launcher-builder
|
FROM rust:1.67 as launcher-builder
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
14
README.md
14
README.md
@ -142,6 +142,20 @@ conda create -n text-generation-inference python=3.9
|
|||||||
conda activate text-generation-inference
|
conda activate text-generation-inference
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You may also need to install Protoc, the Protocol Buffers compiler.
|
||||||
|
|
||||||
|
On Linux:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
apt install -y protobuf-compiler
|
||||||
|
```
|
||||||
|
|
||||||
|
On macOS, using Homebrew:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
brew install protobuf
|
||||||
|
```
|
||||||
|
|
||||||
Then run:
|
Then run:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"]
|
|||||||
description = "Text Generation Launcher"
|
description = "Text Generation Launcher"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||||
ctrlc = { version = "3.2.3", features = ["termination"] }
|
ctrlc = { version = "3.2.5", features = ["termination"] }
|
||||||
serde_json = "1.0.89"
|
serde_json = "1.0.93"
|
||||||
subprocess = "0.2.9"
|
subprocess = "0.2.9"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
float_eq = "1.0.1"
|
float_eq = "1.0.1"
|
||||||
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
reqwest = { version = "0.11.14", features = ["blocking", "json"] }
|
||||||
serde = { version = "1.0.150", features = ["derive"] }
|
serde = { version = "1.0.152", features = ["derive"] }
|
||||||
|
@ -165,7 +165,7 @@ fn main() -> ExitCode {
|
|||||||
"--port".to_string(),
|
"--port".to_string(),
|
||||||
port.to_string(),
|
port.to_string(),
|
||||||
"--master-shard-uds-path".to_string(),
|
"--master-shard-uds-path".to_string(),
|
||||||
format!("{}-0", shard_uds_path),
|
format!("{shard_uds_path}-0"),
|
||||||
"--tokenizer-name".to_string(),
|
"--tokenizer-name".to_string(),
|
||||||
model_id,
|
model_id,
|
||||||
];
|
];
|
||||||
@ -269,7 +269,7 @@ fn shard_manager(
|
|||||||
_shutdown_sender: mpsc::Sender<()>,
|
_shutdown_sender: mpsc::Sender<()>,
|
||||||
) {
|
) {
|
||||||
// Get UDS path
|
// Get UDS path
|
||||||
let uds_string = format!("{}-{}", uds_path, rank);
|
let uds_string = format!("{uds_path}-{rank}");
|
||||||
let uds = Path::new(&uds_string);
|
let uds = Path::new(&uds_string);
|
||||||
// Clean previous runs
|
// Clean previous runs
|
||||||
fs::remove_file(uds).unwrap_or_default();
|
fs::remove_file(uds).unwrap_or_default();
|
||||||
|
@ -15,23 +15,24 @@ path = "src/main.rs"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
async-stream = "0.3.3"
|
async-stream = "0.3.3"
|
||||||
axum = { version = "0.6.4", features = ["json"] }
|
axum = { version = "0.6.4", features = ["json"] }
|
||||||
|
axum-tracing-opentelemetry = "0.9.0"
|
||||||
text-generation-client = { path = "client" }
|
text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||||
futures = "0.3.24"
|
futures = "0.3.26"
|
||||||
nohash-hasher = "0.2.0"
|
nohash-hasher = "0.2.0"
|
||||||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = "0.11.0"
|
opentelemetry-otlp = "0.11.0"
|
||||||
parking_lot = "0.12.1"
|
parking_lot = "0.12.1"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
serde = "1.0.145"
|
serde = "1.0.152"
|
||||||
serde_json = "1.0.85"
|
serde_json = "1.0.93"
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.38"
|
||||||
tokenizers = "0.13.0"
|
tokenizers = "0.13.2"
|
||||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
tokio-stream = "0.1.11"
|
tokio-stream = "0.1.11"
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.37"
|
||||||
tracing-opentelemetry = "0.18.0"
|
tracing-opentelemetry = "0.18.0"
|
||||||
tracing-subscriber = { version = "0.3.15", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||||
|
|
||||||
|
@ -6,13 +6,13 @@ edition = "2021"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "^0.3"
|
futures = "^0.3"
|
||||||
grpc-metadata = { path = "../grpc-metadata" }
|
grpc-metadata = { path = "../grpc-metadata" }
|
||||||
prost = "^0.9"
|
prost = "^0.11"
|
||||||
thiserror = "^1.0"
|
thiserror = "^1.0"
|
||||||
tokio = { version = "^1.21", features = ["sync"] }
|
tokio = { version = "^1.25", features = ["sync"] }
|
||||||
tonic = "^0.6"
|
tonic = "^0.8"
|
||||||
tower = "^0.4"
|
tower = "^0.4"
|
||||||
tracing = "^0.1"
|
tracing = "^0.1"
|
||||||
tracing-error = "^0.2"
|
tracing-error = "^0.2"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
tonic-build = "0.6.2"
|
tonic-build = "^0.8"
|
||||||
|
@ -9,7 +9,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
.out_dir("src/pb")
|
.out_dir("src/pb")
|
||||||
.include_file("mod.rs")
|
.include_file("mod.rs")
|
||||||
.compile(&["../../proto/generate.proto"], &["../../proto"])
|
.compile(&["../../proto/generate.proto"], &["../../proto"])
|
||||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
|
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,6 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
opentelemetry = "0.18.0"
|
opentelemetry = "0.18.0"
|
||||||
tonic = "^0.6"
|
tonic = "^0.8"
|
||||||
tracing = "^0.1"
|
tracing = "^0.1"
|
||||||
tracing-opentelemetry = "0.18.0"
|
tracing-opentelemetry = "0.18.0"
|
||||||
|
@ -33,7 +33,7 @@ impl<'a> Injector for MetadataInjector<'a> {
|
|||||||
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
|
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
|
||||||
fn set(&mut self, key: &str, value: String) {
|
fn set(&mut self, key: &str, value: String) {
|
||||||
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
|
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
|
||||||
if let Ok(val) = tonic::metadata::MetadataValue::from_str(&value) {
|
if let Ok(val) = value.parse() {
|
||||||
self.0.insert(key, val);
|
self.0.insert(key, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
use opentelemetry::sdk::trace;
|
use opentelemetry::sdk::trace;
|
||||||
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
use opentelemetry::sdk::Resource;
|
use opentelemetry::sdk::Resource;
|
||||||
use opentelemetry::{global, KeyValue};
|
use opentelemetry::{global, KeyValue};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
@ -130,21 +131,26 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|||||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||||
|
|
||||||
let tracer =
|
let tracer = opentelemetry_otlp::new_pipeline()
|
||||||
opentelemetry_otlp::new_pipeline()
|
.tracing()
|
||||||
.tracing()
|
.with_exporter(
|
||||||
.with_exporter(
|
opentelemetry_otlp::new_exporter()
|
||||||
opentelemetry_otlp::new_exporter()
|
.tonic()
|
||||||
.tonic()
|
.with_endpoint(otlp_endpoint),
|
||||||
.with_endpoint(otlp_endpoint),
|
)
|
||||||
)
|
.with_trace_config(
|
||||||
.with_trace_config(trace::config().with_resource(Resource::new(vec![
|
trace::config()
|
||||||
KeyValue::new("service.name", "text-generation-inference.router"),
|
.with_resource(Resource::new(vec![KeyValue::new(
|
||||||
])))
|
"service.name",
|
||||||
.install_batch(opentelemetry::runtime::Tokio);
|
"text-generation-inference.router",
|
||||||
|
)]))
|
||||||
|
.with_sampler(Sampler::AlwaysOn),
|
||||||
|
)
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio);
|
||||||
|
|
||||||
if let Ok(tracer) = tracer {
|
if let Ok(tracer) = tracer {
|
||||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||||
|
axum_tracing_opentelemetry::init_propagator().unwrap();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,7 +267,7 @@ mod tests {
|
|||||||
state.append(default_entry());
|
state.append(default_entry());
|
||||||
state.append(default_entry());
|
state.append(default_entry());
|
||||||
|
|
||||||
let (entries, batch) = state.next_batch(None, 2).unwrap();
|
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -296,7 +296,7 @@ mod tests {
|
|||||||
state.append(default_entry());
|
state.append(default_entry());
|
||||||
state.append(default_entry());
|
state.append(default_entry());
|
||||||
|
|
||||||
let (entries, batch) = state.next_batch(None, 1).unwrap();
|
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -308,7 +308,7 @@ mod tests {
|
|||||||
|
|
||||||
state.append(default_entry());
|
state.append(default_entry());
|
||||||
|
|
||||||
let (entries, batch) = state.next_batch(None, 3).unwrap();
|
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -340,7 +340,7 @@ mod tests {
|
|||||||
queue.append(default_entry());
|
queue.append(default_entry());
|
||||||
queue.append(default_entry());
|
queue.append(default_entry());
|
||||||
|
|
||||||
let (entries, batch) = queue.next_batch(None, 2).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -360,7 +360,7 @@ mod tests {
|
|||||||
queue.append(default_entry());
|
queue.append(default_entry());
|
||||||
queue.append(default_entry());
|
queue.append(default_entry());
|
||||||
|
|
||||||
let (entries, batch) = queue.next_batch(None, 1).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -368,7 +368,7 @@ mod tests {
|
|||||||
|
|
||||||
queue.append(default_entry());
|
queue.append(default_entry());
|
||||||
|
|
||||||
let (entries, batch) = queue.next_batch(None, 3).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
|
@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
|
|||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
@ -135,11 +136,11 @@ async fn generate(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Tracing metadata
|
// Tracing metadata
|
||||||
span.record("total_time", format!("{:?}", total_time));
|
span.record("total_time", format!("{total_time:?}"));
|
||||||
span.record("validation_time", format!("{:?}", validation_time));
|
span.record("validation_time", format!("{validation_time:?}"));
|
||||||
span.record("queue_time", format!("{:?}", queue_time));
|
span.record("queue_time", format!("{queue_time:?}"));
|
||||||
span.record("inference_time", format!("{:?}", inference_time));
|
span.record("inference_time", format!("{inference_time:?}"));
|
||||||
span.record("time_per_token", format!("{:?}", time_per_token));
|
span.record("time_per_token", format!("{time_per_token:?}"));
|
||||||
span.record("seed", format!("{:?}", response.generated_text.seed));
|
span.record("seed", format!("{:?}", response.generated_text.seed));
|
||||||
tracing::info!("Output: {}", response.generated_text.text);
|
tracing::info!("Output: {}", response.generated_text.text);
|
||||||
|
|
||||||
@ -355,7 +356,8 @@ pub async fn run(
|
|||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.layer(Extension(infer));
|
.layer(Extension(infer))
|
||||||
|
.layer(opentelemetry_tracing_layer());
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
[toolchain]
|
[toolchain]
|
||||||
channel = "1.65.0"
|
channel = "1.67.0"
|
||||||
components = ["rustfmt", "clippy"]
|
components = ["rustfmt", "clippy"]
|
Loading…
Reference in New Issue
Block a user