mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00

feat(client): Simplify sharded logic

parent c8ce9b2515
commit beb552127a
@@ -2,76 +2,18 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
-use tokio::sync::{broadcast, mpsc};
+use futures::future::select_all;
 use tonic::transport::Uri;
 
-/// List of all available commands that can be sent through the command channel
-#[derive(Clone, Debug)]
-enum Command {
-    Generate(
-        Batch,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    GenerateWithCache(
-        Vec<Batch>,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    ClearCache(mpsc::Sender<Result<()>>),
-}
-
-/// Tokio task that handles the communication with a single shard
-///
-/// We subscribe on a broadcast channel to receive commands that will be sent by
-/// the ShardedClient.
-///
-/// Each command is fan out to all shards.
-///
-/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
-/// producer = the shards, single consumer = the ShardedClient).
-async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
-    while let Ok(message) = request_subscriber.recv().await {
-        match message {
-            Command::Generate(batch, response_tx) => {
-                let result = client.generate(batch).await;
-                // We can unwrap_or(()) here because the only error that can happen is if the
-                // receiver is dropped, which means that the ShardedClient already received a
-                // response from another shard
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::GenerateWithCache(batches, response_tx) => {
-                let result = client.generate_with_cache(batches).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::ClearCache(response_tx) => {
-                let result = client.clear_cache().await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-        };
-    }
-}
 
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
-    _clients: Vec<Client>,
-    request_tx: broadcast::Sender<Command>,
+    clients: Vec<Client>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        // The broadcast channel to communicate with the shards
-        // We use a capacity of one as the shards are not asynchronous and can only process one
-        // command at a time
-        let (request_tx, _) = broadcast::channel(1);
-
-        // Spawn client tasks
-        for client in clients.iter() {
-            let request_subscriber = request_tx.subscribe();
-            tokio::spawn(client_task(client.clone(), request_subscriber));
-        }
-
-        Self {
-            _clients: clients,
-            request_tx,
-        }
+        Self { clients }
     }
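The core of the simplification is visible above: the `Command` enum, the broadcast channel, and the per-shard `client_task`s are all gone, and `ShardedClient` now simply owns its `Client`s and races one future per shard. A minimal, self-contained sketch of that pattern with `futures::future::select_all` (the shard count and delays here are invented for illustration; requires the `futures` and `tokio` crates):

```rust
use futures::future::select_all;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Three stand-in "shards" that all produce the same answer, at different speeds.
    let shards: Vec<_> = (1..=3u64)
        .map(|i| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(10 * i)).await;
                format!("reply from shard {i}")
            })
        })
        .collect();

    // select_all resolves as soon as the first future completes and returns
    // (output, index_of_winner, remaining_futures); the leftovers are dropped.
    let (reply, winner, _rest) = select_all(shards).await;
    println!("shard at index {winner} answered first: {reply}");
}
```

Dropping the losing futures is safe for the same reason the deleted `unwrap_or(())` comment gave: every shard returns the same result, so only the first reply matters.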
@@ -101,15 +43,15 @@ impl ShardedClient
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::Generate(batch, response_tx))
-            .unwrap();
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate(batch.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Generate one token for each request in the given cached batch
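A note on the `Box::pin` in the new `generate`: `select_all` requires its futures to be `Unpin`, and the anonymous future returned by an async method is not; pinning it on the heap yields a `Pin<Box<_>>`, which always is. A small sketch under that assumption (`shard_reply` is an invented stand-in, not a crate API):

```rust
use futures::future::select_all;

// Stand-in for a per-shard gRPC call; its return type is an opaque
// `impl Future` that is *not* Unpin.
async fn shard_reply(id: usize) -> usize {
    id
}

#[tokio::main]
async fn main() {
    // Without Box::pin this would not compile: select_all requires
    // `Fut: Future + Unpin`, and Pin<Box<F>> satisfies Unpin.
    let futures: Vec<_> = (0..4).map(|i| Box::pin(shard_reply(i))).collect();
    let (first, _index, _remaining) = select_all(futures).await;
    println!("first shard to answer returned {first}");
}
```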
@@ -117,27 +59,26 @@
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     pub async fn generate_with_cache(
-        &self,
+        &mut self,
         batches: Vec<Batch>,
     ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateWithCache(batches, response_tx))
-            .unwrap();
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Clear the past generations cache
-    pub async fn clear_cache(&self) -> Result<()> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::ClearCache(response_tx))
-            .unwrap();
-        response_rx.recv().await.unwrap()
+    pub async fn clear_cache(&mut self) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache())
+            .collect();
+        join_all(futures).await.into_iter().collect()
     }
 }
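`clear_cache` takes the opposite aggregation strategy: the cache must be cleared on every shard, so the new code awaits all futures with `join_all` and folds the resulting `Vec<Result<()>>` into a single `Result<()>` via `collect`, which yields the first error if any shard failed. A minimal sketch of that folding behavior (the `String` error type is assumed for illustration, not the crate's):

```rust
use futures::future::join_all;

#[tokio::main]
async fn main() {
    // Simulated per-shard outcomes: the second shard fails.
    let shard_results = vec![Ok(()), Err("shard 2 failed".to_string())];
    let futures: Vec<_> = shard_results
        .into_iter()
        .map(|result| async move { result }) // each "shard" just reports its result
        .collect();

    // FromIterator<Result<(), E>> for Result<(), E>: Ok(()) only if every
    // shard succeeded, otherwise the first Err encountered.
    let outcome: Result<(), String> = join_all(futures).await.into_iter().collect();
    assert_eq!(outcome, Err("shard 2 failed".to_string()));
}
```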
@@ -39,9 +39,9 @@ impl Batcher
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
+            client,
             max_batch_size,
             max_waiting_tokens,
-            client,
             db.clone(),
             shared.clone(),
         ));
@@ -86,9 +86,9 @@ impl Batcher
 /// Batches requests and sends them to the inference server
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
+    mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
 ) {
@@ -61,7 +61,7 @@ fn main() -> Result<(), std::io::Error> {
         .unwrap()
         .block_on(async {
             // Instantiate sharded client from the master unix socket
-            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
             // Clear the cache; useful if the webserver rebooted
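Since the command channel is gone, `generate`, `generate_with_cache`, and `clear_cache` now call `iter_mut()` on the owned clients and therefore take `&mut self`; that is why the callers above change to `mut client` and `let mut sharded_client`. A toy illustration of the rule involved (the struct here is a stand-in, not the real type):

```rust
// Methods taking &mut self can only be called through a mutable binding.
struct ShardedClient {
    clients: Vec<u32>, // stand-in for Vec<Client>
}

impl ShardedClient {
    fn clear_cache(&mut self) {
        // iter_mut() needs unique access, hence &mut self.
        for client in self.clients.iter_mut() {
            *client = 0;
        }
    }
}

fn main() {
    // With `let sharded_client = ...` the call below would not compile.
    let mut sharded_client = ShardedClient { clients: vec![1, 2, 3] };
    sharded_client.clear_cache();
}
```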