mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Use axum
This commit is contained in:
parent
e86ecbac63
commit
39df4d9975
202
router/Cargo.lock
generated
202
router/Cargo.lock
generated
@ -81,6 +81,53 @@ version = "1.1.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum"
|
||||||
|
version = "0.5.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"axum-core",
|
||||||
|
"bitflags",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"hyper",
|
||||||
|
"itoa",
|
||||||
|
"matchit",
|
||||||
|
"memchr",
|
||||||
|
"mime",
|
||||||
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"serde_urlencoded",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tokio",
|
||||||
|
"tower",
|
||||||
|
"tower-http",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-core"
|
||||||
|
version = "0.2.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"mime",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "base64"
|
name = "base64"
|
||||||
version = "0.13.0"
|
version = "0.13.0"
|
||||||
@ -106,10 +153,10 @@ dependencies = [
|
|||||||
name = "bloom-inference"
|
name = "bloom-inference"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"axum",
|
||||||
"bloom-inference-client",
|
"bloom-inference-client",
|
||||||
"futures",
|
"futures",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"poem",
|
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
@ -661,31 +708,6 @@ version = "0.12.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "headers"
|
|
||||||
version = "0.3.8"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
|
|
||||||
dependencies = [
|
|
||||||
"base64",
|
|
||||||
"bitflags",
|
|
||||||
"bytes",
|
|
||||||
"headers-core",
|
|
||||||
"http",
|
|
||||||
"httpdate",
|
|
||||||
"mime",
|
|
||||||
"sha1",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "headers-core"
|
|
||||||
version = "0.2.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
|
|
||||||
dependencies = [
|
|
||||||
"http",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "heck"
|
name = "heck"
|
||||||
version = "0.3.3"
|
version = "0.3.3"
|
||||||
@ -726,6 +748,12 @@ dependencies = [
|
|||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "http-range-header"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "httparse"
|
name = "httparse"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
@ -941,6 +969,12 @@ version = "0.1.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
|
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matchit"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
@ -1201,65 +1235,12 @@ version = "0.3.25"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
|
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "poem"
|
|
||||||
version = "1.3.45"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2992ba72908e36200671c0f3a692992ced894b3b2bbe2b2dc6dfbffea6e2c85a"
|
|
||||||
dependencies = [
|
|
||||||
"async-trait",
|
|
||||||
"bytes",
|
|
||||||
"futures-util",
|
|
||||||
"headers",
|
|
||||||
"http",
|
|
||||||
"hyper",
|
|
||||||
"mime",
|
|
||||||
"parking_lot",
|
|
||||||
"percent-encoding",
|
|
||||||
"pin-project-lite",
|
|
||||||
"poem-derive",
|
|
||||||
"regex",
|
|
||||||
"rfc7239",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"serde_urlencoded",
|
|
||||||
"smallvec",
|
|
||||||
"thiserror",
|
|
||||||
"tokio",
|
|
||||||
"tokio-stream",
|
|
||||||
"tokio-util 0.7.4",
|
|
||||||
"tracing",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "poem-derive"
|
|
||||||
version = "1.3.45"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "9f535d4331a22610b98ca48f98bae9bda0c654da89b9ae10a1830fa9edfd8f36"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro-crate",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ppv-lite86"
|
name = "ppv-lite86"
|
||||||
version = "0.2.16"
|
version = "0.2.16"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "proc-macro-crate"
|
|
||||||
version = "1.2.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9"
|
|
||||||
dependencies = [
|
|
||||||
"once_cell",
|
|
||||||
"thiserror",
|
|
||||||
"toml",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.46"
|
version = "1.0.46"
|
||||||
@ -1479,15 +1460,6 @@ dependencies = [
|
|||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rfc7239"
|
|
||||||
version = "0.1.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "087317b3cf7eb481f13bd9025d729324b7cd068d6f470e2d76d049e191f5ba47"
|
|
||||||
dependencies = [
|
|
||||||
"uncased",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.11"
|
version = "1.0.11"
|
||||||
@ -1576,17 +1548,6 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "sha1"
|
|
||||||
version = "0.10.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"cpufeatures",
|
|
||||||
"digest",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sha2"
|
name = "sha2"
|
||||||
version = "0.10.6"
|
version = "0.10.6"
|
||||||
@ -1667,6 +1628,12 @@ dependencies = [
|
|||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sync_wrapper"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tar"
|
name = "tar"
|
||||||
version = "0.4.38"
|
version = "0.4.38"
|
||||||
@ -1890,15 +1857,6 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "toml"
|
|
||||||
version = "0.5.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7"
|
|
||||||
dependencies = [
|
|
||||||
"serde",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tonic"
|
name = "tonic"
|
||||||
version = "0.6.2"
|
version = "0.6.2"
|
||||||
@ -1962,6 +1920,25 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower-http"
|
||||||
|
version = "0.3.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"http-range-header",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tower",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-layer"
|
name = "tower-layer"
|
||||||
version = "0.3.1"
|
version = "0.3.1"
|
||||||
@ -2065,15 +2042,6 @@ version = "1.15.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "uncased"
|
|
||||||
version = "0.9.7"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "09b01702b0fd0b3fadcf98e098780badda8742d4f4a7676615cad90e8ac73622"
|
|
||||||
dependencies = [
|
|
||||||
"version_check",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-bidi"
|
name = "unicode-bidi"
|
||||||
version = "0.3.8"
|
version = "0.3.8"
|
||||||
|
@ -3,13 +3,11 @@ name = "bloom-inference"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||||
bloom-inference-client = { path = "client" }
|
bloom-inference-client = { path = "client" }
|
||||||
futures = "0.3.24"
|
futures = "0.3.24"
|
||||||
parking_lot = "0.12.1"
|
parking_lot = "0.12.1"
|
||||||
poem = "1.3.45"
|
|
||||||
serde = "1.0.145"
|
serde = "1.0.145"
|
||||||
serde_json = "1.0.85"
|
serde_json = "1.0.85"
|
||||||
tokenizers = "0.13.0"
|
tokenizers = "0.13.0"
|
||||||
|
@ -4,12 +4,12 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "0.3.24"
|
futures = "^0.3"
|
||||||
#grpc-error-details = { path = "../../grpc-error-details" }
|
#grpc-error-details = { path = "../../grpc-error-details" }
|
||||||
#grpc-metadata = { path = "../../grpc-metadata" }
|
#grpc-metadata = { path = "../../grpc-metadata" }
|
||||||
prost = "^0.9"
|
prost = "^0.9"
|
||||||
thiserror = "1.0.37"
|
thiserror = "^1.0"
|
||||||
tokio = { version = "1.21.2", features = ["sync"] }
|
tokio = { version = "^1.21", features = ["sync"] }
|
||||||
tonic = "^0.6"
|
tonic = "^0.6"
|
||||||
tower = "^0.4"
|
tower = "^0.4"
|
||||||
tracing = "^0.1"
|
tracing = "^0.1"
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
use bloom_inference_client::ShardedClient;
|
use bloom_inference_client::ShardedClient;
|
||||||
use poem::listener::TcpListener;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
@ -37,9 +37,9 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
.expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
let addr = "127.0.0.1:3000".to_string();
|
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||||
let listener = TcpListener::bind(addr);
|
|
||||||
|
|
||||||
server::run(sharded_client, tokenizer, listener).await
|
server::run(sharded_client, tokenizer, addr).await;
|
||||||
|
Ok(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use axum::{Router, Json};
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::extract::Extension;
|
||||||
|
use axum::routing::post;
|
||||||
use crate::{Batcher, ShardedClient, Validation};
|
use crate::{Batcher, ShardedClient, Validation};
|
||||||
use poem::http::StatusCode;
|
|
||||||
use poem::listener::TcpListener;
|
|
||||||
use poem::middleware::AddData;
|
|
||||||
use poem::web::{Data, Json};
|
|
||||||
use poem::{handler, post, EndpointExt, Route, Server};
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
@ -60,26 +60,24 @@ pub(crate) struct GenerateRequest {
|
|||||||
pub parameters: GenerateParameters,
|
pub parameters: GenerateParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[handler]
|
#[instrument(skip(state), fields(time, time_per_token))]
|
||||||
#[instrument(skip(validation, infer), fields(time, time_per_token))]
|
|
||||||
async fn generate(
|
async fn generate(
|
||||||
validation: Data<&Validation>,
|
state: Extension<ServerState>,
|
||||||
infer: Data<&Batcher>,
|
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> poem::Result<Json<serde_json::Value>> {
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let (input_length, validated_request) = match validation
|
let (input_length, validated_request) = match state.validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
inputs: req.inputs.clone(),
|
inputs: req.inputs.clone(),
|
||||||
parameters: req.parameters.clone(),
|
parameters: req.parameters.clone(),
|
||||||
})
|
})
|
||||||
.await {
|
.await {
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
Err(_) => return Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
|
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
};
|
};
|
||||||
|
|
||||||
let output = infer.infer(input_length, validated_request).await;
|
let output = state.infer.infer(input_length, validated_request).await;
|
||||||
|
|
||||||
match output {
|
match output {
|
||||||
Ok(generated_text) => {
|
Ok(generated_text) => {
|
||||||
@ -94,15 +92,21 @@ async fn generate(
|
|||||||
"generated_text": generated_text,
|
"generated_text": generated_text,
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ServerState {
|
||||||
|
validation: Validation,
|
||||||
|
infer: Batcher,
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
listener: TcpListener<String>,
|
addr: SocketAddr,
|
||||||
) -> Result<(), std::io::Error> {
|
) {
|
||||||
client.clear_cache().await.expect("Unable to clear cache");
|
client.clear_cache().await.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
@ -110,10 +114,13 @@ pub async fn run(
|
|||||||
|
|
||||||
let validation = Validation::new(tokenizer);
|
let validation = Validation::new(tokenizer);
|
||||||
|
|
||||||
let app = Route::new()
|
let shared_state = ServerState {
|
||||||
.at("/generate", post(generate))
|
validation,
|
||||||
.with(AddData::new(validation))
|
infer,
|
||||||
.with(AddData::new(infer));
|
};
|
||||||
|
|
||||||
Server::new(listener).run(app).await
|
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state));
|
||||||
|
|
||||||
|
axum::Server::bind(&addr)
|
||||||
|
.serve(app.into_make_service()).await.unwrap();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user