feat: Add arguments to CLI

Olivier Dehaene 2022-10-17 18:27:33 +02:00
parent 5e5d8766a2
commit 92c1ecd008
9 changed files with 163 additions and 35 deletions

Dockerfile

@@ -9,7 +9,7 @@ WORKDIR /usr/src/router

 RUN cargo install --path .

-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
+FROM nvidia/cuda:11.6.1-devel-ubuntu18.04

 ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \

README.md

@@ -43,7 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire

 ## TODO:

-- [ ] Add batching args to router CLI
 - [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
 - [ ] Add tests
 - [ ] Add shutdown logic in router and server

router/Cargo.lock (generated, 102 changed lines)

@@ -253,6 +253,43 @@ dependencies = [
  "vec_map",
 ]

+[[package]]
+name = "clap"
+version = "4.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
+dependencies = [
+ "atty",
+ "bitflags",
+ "clap_derive",
+ "clap_lex",
+ "once_cell",
+ "strsim 0.10.0",
+ "termcolor",
+]
+
+[[package]]
+name = "clap_derive"
+version = "4.0.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
+dependencies = [
+ "heck 0.4.0",
+ "proc-macro-error",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "clap_lex"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8"
+dependencies = [
+ "os_str_bytes",
+]
+
 [[package]]
 name = "console"
 version = "0.15.2"
@@ -701,6 +738,12 @@ dependencies = [
  "unicode-segmentation",
 ]

+[[package]]
+name = "heck"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
+
 [[package]]
 name = "hermit-abi"
 version = "0.1.19"
@@ -1136,6 +1179,12 @@ dependencies = [
  "vcpkg",
 ]

+[[package]]
+name = "os_str_bytes"
+version = "6.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.1"
@@ -1225,6 +1274,30 @@ version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"

+[[package]]
+name = "proc-macro-error"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
+dependencies = [
+ "proc-macro-error-attr",
+ "proc-macro2",
+ "quote",
+ "syn",
+ "version_check",
+]
+
+[[package]]
+name = "proc-macro-error-attr"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "version_check",
+]
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.46"
@@ -1251,7 +1324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5"
 dependencies = [
  "bytes",
- "heck",
+ "heck 0.3.3",
  "itertools 0.10.5",
  "lazy_static",
  "log",
@@ -1601,6 +1674,12 @@ version = "0.9.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"

+[[package]]
+name = "strsim"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
+
 [[package]]
 name = "syn"
 version = "1.0.101"
@@ -1643,6 +1722,15 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "termcolor"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "terminal_size"
 version = "0.1.17"
@@ -1659,6 +1747,7 @@ version = "0.1.0"
 dependencies = [
  "axum",
  "bloom-inference-client",
+ "clap 4.0.15",
  "futures",
  "parking_lot",
  "serde",
@@ -1742,7 +1831,7 @@ checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
 dependencies = [
  "aho-corasick",
  "cached-path",
- "clap",
+ "clap 2.34.0",
  "derive_builder",
  "dirs",
  "esaxx-rs",
@@ -2251,6 +2340,15 @@ version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"

+[[package]]
+name = "winapi-util"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
+dependencies = [
+ "winapi",
+]
+
 [[package]]
 name = "winapi-x86_64-pc-windows-gnu"
 version = "0.4.0"

router/Cargo.toml

@@ -2,6 +2,8 @@
 name = "text-generation-router"
 version = "0.1.0"
 edition = "2021"
+authors = ["Olivier Dehaene"]
+description = "Text Generation Webserver"

 [lib]
 path = "src/lib.rs"
@@ -13,6 +15,7 @@ path = "src/main.rs"
 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
 bloom-inference-client = { path = "client" }
+clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
 parking_lot = "0.12.1"
 serde = "1.0.145"

router/src/batcher.rs

@@ -27,7 +27,7 @@ impl From<InferError> for (StatusCode, String) {
 }

 #[derive(Clone)]
-pub(crate) struct Batcher {
+pub struct Batcher {
     db: Db,
     shared: Arc<Shared>,
 }
@@ -37,13 +37,13 @@ struct Shared {
 }

 impl Batcher {
-    pub(crate) fn new(client: ShardedClient) -> Self {
+    pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });

-        tokio::spawn(batching_task(client, db.clone(), shared.clone()));
+        tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));

         Self { db, shared }
     }
@@ -70,40 +70,46 @@
     }
 }

-async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
+async fn batching_task(max_batch_size: usize,
+                       client: ShardedClient,
+                       db: Db,
+                       shared: Arc<Shared>) {
+    let limit_min_batch_size = (max_batch_size / 2) as u32;
+
     loop {
         shared.batching_task.notified().await;

-        if let Some(batch) = db.next_batch(32) {
+        if let Some(batch) = db.next_batch(max_batch_size) {
             let request_ids = batch.requests.iter().map(|req| req.id).collect();
             let mut cached_batch = match batch.size {
-                size if size > 16 => {
+                size if size > limit_min_batch_size => {
                     wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                 }
                 _ => wrap_future(client.generate(batch), request_ids, &db).await,
             };

             while let Some(batch) = cached_batch {
-                let batch_size = batch.size;
+                let mut current_batch_size = batch.size;
                 let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];

-                if batch_size <= 16 {
-                    if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
+                if current_batch_size <= limit_min_batch_size {
+                    if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
                         let new_batch_request_ids =
                             new_batch.requests.iter().map(|req| req.id).collect();
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                 .await;
                         if let Some(new_cached_batch) = new_cached_batch {
+                            current_batch_size += new_cached_batch.size;
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                             batches.push(new_cached_batch);
                         }
                     }
                 }

-                cached_batch = match batch_size {
-                    size if size > 16 => {
+                cached_batch = match current_batch_size {
+                    size if size > limit_min_batch_size => {
                         wrap_future(
                             client.generate_until_finished_with_cache(batches),
                             request_ids,
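The substance of this hunk: the magic numbers 32, 16, and 48 are replaced by values derived from the new `max_batch_size` argument, and the running batch size is now tracked in `current_batch_size` after concatenation instead of reusing the stale `batch_size`. A minimal standalone sketch of the derived threshold logic; the `plan` helper and its return strings are illustrative only, not part of the commit:

```rust
// Illustrative sketch of the thresholds batching_task derives above.
// Batches larger than half of max_batch_size run to completion
// (generate_until_finished); smaller ones run a single step (generate)
// so the task can first try to top them up from the queue.
fn plan(max_batch_size: usize, batch_size: u32) -> &'static str {
    let limit_min_batch_size = (max_batch_size / 2) as u32;
    if batch_size > limit_min_batch_size {
        "generate_until_finished"
    } else {
        "generate one step, then try to extend the batch"
    }
}

fn main() {
    // With the default max_batch_size = 32 the old constants fall out:
    // limit_min_batch_size = 16, and next_batch_minimum_size(16, 32)
    // replaces the previous hard-coded next_batch_minimum_size(16, 48).
    assert_eq!(plan(32, 20), "generate_until_finished");
    assert_eq!(plan(32, 8), "generate one step, then try to extend the batch");
}
```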

router/src/lib.rs

@@ -1,8 +1,8 @@
 mod batcher;
 mod db;
+pub mod server;
 mod validation;
-pub mod server;

+use batcher::Batcher;
 use db::{Db, Entry};
-use batcher::Batcher;
 use validation::Validation;

router/src/main.rs

@@ -1,10 +1,36 @@
 use bloom_inference_client::ShardedClient;
-use std::net::SocketAddr;
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use text_generation_router::server;
 use tokenizers::Tokenizer;
+use clap::Parser;
+
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "32", long, short, env)]
+    max_batch_size: usize,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
+    shard_uds_path: String,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+}

 fn main() -> Result<(), std::io::Error> {
-    let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
+    // Get args
+    let args = Args::parse();
+    // Pattern match configuration
+    let Args {
+        max_batch_size,
+        port,
+        shard_uds_path,
+        tokenizer_name,
+    } = args;
+
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();

     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
@@ -13,7 +39,7 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();

-            let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
+            let sharded_client = ShardedClient::connect_uds(shard_uds_path)
                 .await
                 .expect("Could not connect to server");
             sharded_client
@@ -22,9 +48,9 @@ fn main() -> Result<(), std::io::Error> {
                 .expect("Unable to clear cache");
             tracing::info!("Connected");

-            let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
+            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

-            server::run(sharded_client, tokenizer, addr).await;
+            server::run(max_batch_size, sharded_client, tokenizer, addr).await;
             Ok(())
         })
 }
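With clap's derive API (the `derive` and `env` Cargo features enable the `Parser` derive and the `env` attribute), each field becomes a `--kebab-case` flag, `short` adds a single-letter alias, and `env` reads an uppercased environment variable as a fallback. A minimal sketch of the resulting behavior; the binary name and values here are hypothetical:

```rust
// Sketch of how the derived parser resolves a value:
// CLI flag beats environment variable, which beats default_value.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    // Exposed as `--max-batch-size <N>`, `-m <N>`, or MAX_BATCH_SIZE=<N>.
    #[clap(default_value = "32", long, short, env)]
    max_batch_size: usize,
}

fn main() {
    let args = Args::try_parse_from(["router", "--max-batch-size", "64"]).unwrap();
    assert_eq!(args.max_batch_size, 64);

    // With no flag (and no MAX_BATCH_SIZE set in the environment),
    // the default applies.
    let args = Args::try_parse_from(["router"]).unwrap();
    assert_eq!(args.max_batch_size, 32);
}
```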

router/src/server.rs

@@ -64,7 +64,7 @@ pub(crate) struct GenerateRequest {
 #[instrument(skip(state), fields(time, time_per_token))]
 async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
     state
-        .infer
+        .batcher
         .infer(
             1,
             GenerateRequest {
@@ -97,7 +97,7 @@ async fn generate(
         })
         .await?;

-    let generated_text = state.infer.infer(input_length, validated_request).await?;
+    let generated_text = state.batcher.infer(input_length, validated_request).await?;

     tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
     tracing::Span::current().record(
@@ -114,18 +114,14 @@ async fn generate(
 #[derive(Clone)]
 struct ServerState {
     validation: Validation,
-    infer: Batcher,
+    batcher: Batcher,
 }

-pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
-    client.clear_cache().await.expect("Unable to clear cache");
-    tracing::info!("Connected");
-
-    let infer = Batcher::new(client);
+pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
+    let batcher = Batcher::new(client, max_batch_size);
     let validation = Validation::new(tokenizer);
-    let shared_state = ServerState { validation, infer };
+    let shared_state = ServerState { validation, batcher };

     let app = Router::new()
         .route("/generate", post(generate))

router/src/validation.rs

@@ -14,7 +14,7 @@ pub enum ValidationError {
     TopK,
     #[error("Max New Tokens must be < 512")]
     MaxNewTokens,
-    #[error("Inputs must have less than 512 tokens. Given: {0}")]
+    #[error("Inputs must have less than 1000 tokens. Given: {0}")]
     InputLength(usize),
 }
@@ -30,7 +30,7 @@ type ValidationRequest = (
 );

 #[derive(Debug, Clone)]
-pub(crate) struct Validation {
+pub struct Validation {
     sender: mpsc::Sender<ValidationRequest>,
 }
@@ -81,7 +81,7 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
     let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
     let input_length = inputs.len();

-    if input_length > 512 {
+    if input_length > 1000 {
         response_tx
             .send(Err(ValidationError::InputLength(input_length)))
             .unwrap_or(());
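The input limit rises from 512 to 1000 tokens, and it applies to the tokenized length of the input, not its character count. A standalone sketch of the rule; the `check_input_length` function is hypothetical and only mirrors the check above:

```rust
use thiserror::Error;

// Hypothetical mirror of the validation rule above: reject inputs whose
// tokenized length exceeds 1000 tokens.
#[derive(Debug, Error)]
pub enum ValidationError {
    #[error("Inputs must have less than 1000 tokens. Given: {0}")]
    InputLength(usize),
}

fn check_input_length(input_length: usize) -> Result<usize, ValidationError> {
    if input_length > 1000 {
        return Err(ValidationError::InputLength(input_length));
    }
    Ok(input_length)
}

fn main() {
    assert!(check_input_length(512).is_ok());
    assert!(check_input_length(1001).is_err());
}
```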