2022-10-08 10:30:12 +00:00
|
|
|
use bloom_inference_client::ShardedClient;
|
2022-10-14 13:56:21 +00:00
|
|
|
use std::net::SocketAddr;
|
2022-10-17 12:59:00 +00:00
|
|
|
use text_generation_router::server;
|
2022-10-11 14:50:54 +00:00
|
|
|
use tokenizers::Tokenizer;
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-11 14:50:54 +00:00
|
|
|
fn main() -> Result<(), std::io::Error> {
|
|
|
|
let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
|
|
|
|
|
|
|
|
tokio::runtime::Builder::new_multi_thread()
|
|
|
|
.enable_all()
|
|
|
|
.build()
|
|
|
|
.unwrap()
|
|
|
|
.block_on(async {
|
|
|
|
tracing_subscriber::fmt::init();
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-17 12:59:00 +00:00
|
|
|
let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
|
|
|
|
.await
|
|
|
|
.expect("Could not connect to server");
|
2022-10-11 14:50:54 +00:00
|
|
|
sharded_client
|
|
|
|
.clear_cache()
|
|
|
|
.await
|
|
|
|
.expect("Unable to clear cache");
|
|
|
|
tracing::info!("Connected");
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-14 13:56:21 +00:00
|
|
|
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-11 16:14:39 +00:00
|
|
|
server::run(sharded_client, tokenizer, addr).await;
|
|
|
|
Ok(())
|
2022-10-11 14:50:54 +00:00
|
|
|
})
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|