mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 04:52:07 +00:00
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Co-authored-by: OlivierDehaene <olivier@huggingface.co>
123 lines
4.2 KiB
Rust
123 lines
4.2 KiB
Rust
/// Multi shard Client
|
|
use crate::Result;
|
|
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
|
|
use futures::future::join_all;
|
|
use tonic::transport::Uri;
|
|
use tracing::instrument;
|
|
|
|
#[derive(Debug, Clone)]
|
|
/// Text Generation Inference gRPC multi client
|
|
pub struct ShardedClient {
|
|
clients: Vec<Client>,
|
|
}
|
|
|
|
impl ShardedClient {
|
|
fn new(clients: Vec<Client>) -> Self {
|
|
Self { clients }
|
|
}
|
|
|
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
|
// Get all uris/unix sockets from the master client
|
|
let uris = master_client.service_discovery().await.unwrap();
|
|
let futures = uris.into_iter().map(Client::connect_uds);
|
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
|
Ok(Self::new(clients?))
|
|
}
|
|
|
|
/// Returns a client connected to the given uri
|
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
|
let master_client = Client::connect(uri).await?;
|
|
Self::from_master_client(master_client).await
|
|
}
|
|
|
|
/// Returns a client connected to the given unix socket
|
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
|
let master_client = Client::connect_uds(path).await?;
|
|
Self::from_master_client(master_client).await
|
|
}
|
|
|
|
/// Get the model info
|
|
#[instrument(skip(self))]
|
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
|
let futures: Vec<_> = self
|
|
.clients
|
|
.iter_mut()
|
|
.map(|client| client.info())
|
|
.collect();
|
|
join_all(futures).await.pop().unwrap()
|
|
}
|
|
|
|
/// GRPC health check
|
|
#[instrument(skip(self))]
|
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
|
let futures: Vec<_> = self
|
|
.clients
|
|
.iter_mut()
|
|
.map(|client| client.health())
|
|
.collect();
|
|
join_all(futures).await.pop().unwrap()
|
|
}
|
|
|
|
/// Clear the past generations cache
|
|
#[instrument(skip(self))]
|
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
|
let futures: Vec<_> = self
|
|
.clients
|
|
.iter_mut()
|
|
.map(|client| client.clear_cache(batch_id))
|
|
.collect();
|
|
join_all(futures).await.into_iter().collect()
|
|
}
|
|
|
|
/// Filter a cached batch
|
|
#[instrument(skip(self))]
|
|
pub async fn filter_batch(
|
|
&mut self,
|
|
batch_id: u64,
|
|
keep_requests: Vec<Request>,
|
|
) -> Result<Option<Batch>> {
|
|
let futures: Vec<_> = self
|
|
.clients
|
|
.iter_mut()
|
|
.map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
|
|
.collect();
|
|
// all shards return the same message
|
|
join_all(futures).await.pop().unwrap()
|
|
}
|
|
|
|
/// Generate one token for each request in the given batch
|
|
///
|
|
/// Returns Generation for each request in batch
|
|
/// and the next cached batch
|
|
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
|
let futures: Vec<_> = self
|
|
.clients
|
|
.iter_mut()
|
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
|
.collect();
|
|
// all shards return the same message
|
|
join_all(futures).await.pop().unwrap()
|
|
}
|
|
|
|
/// Generate one token for each request in the given cached batches
|
|
///
|
|
/// Returns Generation for each request in batches
|
|
/// and the next cached batch
|
|
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
|
pub async fn decode(
|
|
&mut self,
|
|
batches: Vec<Batch>,
|
|
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
|
let futures: Vec<_> = self
|
|
.clients
|
|
.iter_mut()
|
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
|
.collect();
|
|
// all shards return the same message
|
|
join_all(futures).await.pop().unwrap()
|
|
}
|
|
}
|