/// Multi-shard client
use crate::Result;
use crate::{Batch, Client, GeneratedText};
use futures::future::join_all;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;

/// List of all available commands that can be sent through the command channel
#[derive(Clone, Debug)]
enum Command {
    Generate(
        Batch,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
    GenerateWithCache(
        Vec<Batch>,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
    ClearCache(mpsc::Sender<Result<()>>),
}

/// Tokio task that handles the communication with a single shard
///
/// We subscribe on a broadcast channel to receive commands that will be sent by
/// the ShardedClient.
///
/// Each command is fanned out to all shards.
///
/// The result of the command is sent back to the ShardedClient through an mpsc channel (multi
/// producer = the shards, single consumer = the ShardedClient).
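///
/// Sketch of the command flow (illustrative only):
///
/// ```text
///                  broadcast                       mpsc
/// ShardedClient ------------->   client_task   ---------->  ShardedClient
///  (one sender)   (fan out)    (one per shard)   (fan in)  (first reply wins)
/// ```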
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
    while let Ok(message) = request_subscriber.recv().await {
        match message {
            Command::Generate(batch, response_tx) => {
                let result = client.generate(batch).await;
                // We can unwrap_or(()) here because the only error that can happen is if the
                // receiver is dropped, which means that the ShardedClient already received a
                // response from another shard
                response_tx.try_send(result).unwrap_or(());
            }
            Command::GenerateWithCache(batches, response_tx) => {
                let result = client.generate_with_cache(batches).await;
                response_tx.try_send(result).unwrap_or(());
            }
            Command::ClearCache(response_tx) => {
                let result = client.clear_cache().await;
                response_tx.try_send(result).unwrap_or(());
            }
        };
    }
}

/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
    _clients: Vec<Client>,
    request_tx: broadcast::Sender<Command>,
}

impl ShardedClient {
    fn new(clients: Vec<Client>) -> Self {
        // The broadcast channel to communicate with the shards
        // We use a capacity of one as the shards are not asynchronous and can only process one
        // command at a time
        let (request_tx, _) = broadcast::channel(1);

        // Spawn client tasks
        for client in clients.iter() {
            let request_subscriber = request_tx.subscribe();
            tokio::spawn(client_task(client.clone(), request_subscriber));
        }

        Self {
            _clients: clients,
            request_tx,
        }
    }

    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
        // Get all uris/unix sockets from the master client
        let uris = master_client.service_discovery().await.unwrap();
        let futures = uris.into_iter().map(Client::connect_uds);
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
    }

    /// Returns a client connected to the given uri
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
        Self::from_master_client(master_client).await
    }

    /// Returns a client connected to the given unix socket
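    ///
    /// A minimal usage sketch; the socket path is hypothetical and a shard
    /// server must already be listening on it, so the example is not compiled:
    ///
    /// ```ignore
    /// let client = ShardedClient::connect_uds("/tmp/text-generation-0".to_string()).await?;
    /// client.clear_cache().await?;
    /// ```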
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
        Self::from_master_client(master_client).await
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns a list of generated texts for the requests that met their stopping criteria
    /// and the next cached batch
    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        // Create a channel to receive the response from the shards
        // We will only ever receive one message on this channel
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::Generate(batch, response_tx))
            .unwrap();
        // As soon as we receive one response, we can return as all shards will return the same
        response_rx.recv().await.unwrap()
    }

    /// Generate one token for each request in the given cached batch
    ///
    /// Returns a list of generated texts for the requests that met their stopping criteria
    /// and the next cached batch
    pub async fn generate_with_cache(
        &self,
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        // Create a channel to receive the response from the shards
        // We will only ever receive one message on this channel
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::GenerateWithCache(batches, response_tx))
            .unwrap();
        // As soon as we receive one response, we can return as all shards will return the same
        response_rx.recv().await.unwrap()
    }

    /// Clear the past generations cache
    pub async fn clear_cache(&self) -> Result<()> {
        // Create a channel to receive the response from the shards
        // We will only ever receive one message on this channel
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::ClearCache(response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }
}
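
#[cfg(test)]
mod tests {
    use super::*;

    // A smoke-test sketch rather than a full unit test: it assumes a shard
    // server is already listening on the (hypothetical) socket path below and
    // that tokio's `rt` and `macros` features are available, so it is ignored
    // by default.
    #[tokio::test]
    #[ignore]
    async fn clear_cache_round_trip() {
        // Connect through the master shard, then fan the ClearCache command
        // out to every shard and wait for the first response.
        let client = ShardedClient::connect_uds("/tmp/text-generation-0".to_string())
            .await
            .expect("a shard server must be running for this sketch");
        client
            .clear_cache()
            .await
            .expect("clear_cache should succeed");
    }
}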