text-generation-inference/router/client/src/sharded_client.rs

/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, Generation};
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;

/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
    clients: Vec<Client>,
}

impl ShardedClient {
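    // The batching methods below (`clear_cache`, `prefill`, `decode`) fan their
    // calls out to every shard client with `join_all`. All shards serve the same
    // model (one tensor-parallel rank each), so they return the same message and
    // keeping a single response is enough.
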
    fn new(clients: Vec<Client>) -> Self {
        Self { clients }
    }

    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
        // Get all uris/unix sockets from the master client
        let uris = master_client.service_discovery().await.unwrap();
        let futures = uris.into_iter().map(Client::connect_uds);
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
    }

    /// Returns a client connected to the given uri
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
        Self::from_master_client(master_client).await
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
        Self::from_master_client(master_client).await
    }

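    // A minimal usage sketch (the socket path is only illustrative; the router is
    // normally pointed at the master shard's unix socket and service discovery
    // finds the remaining shards):
    //
    //     let mut sharded_client = ShardedClient::connect_uds("/tmp/text-generation-0".to_string()).await?;
    //     sharded_client.clear_cache(None).await?;
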
    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.clear_cache(batch_id))
            .collect();
        join_all(futures).await.into_iter().collect()
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.prefill(batch.clone())))
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<Batch>,
    ) -> Result<(Vec<Generation>, Option<Batch>)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.decode(batches.clone())))
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }
}
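
// Hypothetical driving loop, continuing the usage sketch above (Batch construction
// and handling of the returned Generations are elided):
//
//     let (generations, mut cached_batch) = sharded_client.prefill(batch).await?;
//     while let Some(batch) = cached_batch {
//         let (generations, next_batch) = sharded_client.decode(vec![batch]).await?;
//         cached_batch = next_batch;
//     }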