// text-generation-inference/router/src/main.rs
/// Text Generation Inference webserver entrypoint
use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use text_generation_router::server;
use tokenizers::Tokenizer;
/// App Configuration
//
// Each field becomes a `--kebab-case` CLI flag and (via `env`) can also be
// supplied through an environment variable of the same upper-cased name.
// Field comments below use plain `//` on purpose: `///` doc comments would
// be picked up by clap as per-flag help text and change the CLI output.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Cap on requests served simultaneously; the webserver's load-shedding
    // limit (enforced in `server::run`, which receives this value)
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    // Maximum accepted input length per request — presumably in tokens,
    // checked by the validation workers; TODO confirm against `server::run`
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    // Upper bound on how many requests are grouped into one shard batch
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    // Seconds a request may wait for a batch to fill before dispatch;
    // converted to a `Duration` in `main`
    #[clap(default_value = "5", long, env)]
    max_waiting_time: u64,
    // TCP port the HTTP server listens on (also exposed as `-p`)
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    // Unix domain socket path of the master shard; consumed by
    // `ShardedClient::connect_uds` in `main`
    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
    master_shard_uds_path: String,
    // Hub identifier passed to `Tokenizer::from_pretrained`; the tokenizer
    // is only used to validate payloads, not to generate
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    // Number of payload-validation workers; the entrypoint rejects
    // configurations without at least one worker
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
}
fn main() -> Result<(), std::io::Error> {
2022-10-17 16:27:33 +00:00
// Get args
let args = Args::parse();
2022-10-18 13:19:03 +00:00
// Pattern match configuration
2022-10-17 16:27:33 +00:00
let Args {
2022-10-18 13:19:03 +00:00
max_concurrent_requests,
max_input_length,
2022-10-17 16:27:33 +00:00
max_batch_size,
2022-10-18 13:19:03 +00:00
max_waiting_time,
2022-10-17 16:27:33 +00:00
port,
2022-10-18 13:19:03 +00:00
master_shard_uds_path,
2022-10-17 16:27:33 +00:00
tokenizer_name,
2022-10-18 13:19:03 +00:00
validation_workers,
2022-10-17 16:27:33 +00:00
} = args;
2022-10-18 13:19:03 +00:00
if validation_workers == 1 {
panic!("validation_workers must be > 0");
}
let max_waiting_time = Duration::from_secs(max_waiting_time);
2022-10-17 16:27:33 +00:00
2022-10-18 13:19:03 +00:00
// Download and instantiate tokenizer
// This will only be used to validate payloads
//
// We need to download it outside of the Tokio runtime
2022-10-17 16:27:33 +00:00
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
2022-10-18 13:19:03 +00:00
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
tracing_subscriber::fmt::init();
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
// Instantiate sharded client from the master unix socket
let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
2022-10-17 12:59:00 +00:00
.await
.expect("Could not connect to server");
2022-10-18 13:19:03 +00:00
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache()
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
// Binds on localhost
2022-10-17 16:27:33 +00:00
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
// Run server
server::run(
max_concurrent_requests,
max_input_length,
max_batch_size,
max_waiting_time,
sharded_client,
tokenizer,
validation_workers,
addr,
)
.await;
2022-10-11 16:14:39 +00:00
Ok(())
})
2022-10-08 10:30:12 +00:00
}