2022-10-18 13:19:03 +00:00
|
|
|
/// Multi shard Client
|
2024-06-04 13:56:56 +00:00
|
|
|
use crate::{v3, Health, ShardInfo};
|
2023-05-10 13:48:21 +00:00
|
|
|
use crate::{ClientError, Result};
|
2024-06-04 13:56:56 +00:00
|
|
|
|
|
|
|
use crate::v3::{Chunk, InfoResponse, Input};
|
|
|
|
use async_trait::async_trait;
|
2022-10-08 10:30:12 +00:00
|
|
|
use futures::future::join_all;
|
|
|
|
use tonic::transport::Uri;
|
2023-02-13 12:02:45 +00:00
|
|
|
use tracing::instrument;
|
2024-06-04 13:56:56 +00:00
|
|
|
use v3::client::{DecodeTimings, PrefillTimings};
|
|
|
|
use v3::{
|
|
|
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
|
|
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
|
|
|
};
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2023-04-26 18:23:54 +00:00
|
|
|
#[derive(Debug, Clone)]
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Text Generation Inference gRPC multi client
|
2022-10-08 10:30:12 +00:00
|
|
|
pub struct ShardedClient {
|
2022-10-22 21:40:05 +00:00
|
|
|
clients: Vec<Client>,
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl ShardedClient {
|
2022-10-18 13:19:03 +00:00
|
|
|
fn new(clients: Vec<Client>) -> Self {
|
2022-10-27 12:25:29 +00:00
|
|
|
Self { clients }
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
|
|
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
2022-10-17 12:59:00 +00:00
|
|
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
2022-10-18 13:19:03 +00:00
|
|
|
// Get all uris/unix sockets from the master client
|
2023-04-27 17:16:35 +00:00
|
|
|
let uris = master_client.service_discovery().await?;
|
2022-10-18 13:19:03 +00:00
|
|
|
let futures = uris.into_iter().map(Client::connect_uds);
|
2022-10-17 12:59:00 +00:00
|
|
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
|
|
|
Ok(Self::new(clients?))
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Returns a client connected to the given uri
|
2022-10-17 12:59:00 +00:00
|
|
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
|
|
|
let master_client = Client::connect(uri).await?;
|
2022-10-08 10:30:12 +00:00
|
|
|
Self::from_master_client(master_client).await
|
|
|
|
}
|
|
|
|
|
2022-10-17 12:59:00 +00:00
|
|
|
/// Returns a client connected to the given unix socket
|
|
|
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
|
|
|
let master_client = Client::connect_uds(path).await?;
|
2022-10-08 10:30:12 +00:00
|
|
|
Self::from_master_client(master_client).await
|
|
|
|
}
|
|
|
|
|
2023-04-21 13:36:29 +00:00
|
|
|
/// Get the model info
|
|
|
|
#[instrument(skip(self))]
|
|
|
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
|
|
|
.map(|client| client.info())
|
|
|
|
.collect();
|
2024-06-04 13:56:56 +00:00
|
|
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
2023-04-21 13:36:29 +00:00
|
|
|
}
|
|
|
|
|
2023-04-26 18:23:54 +00:00
|
|
|
/// GRPC health check
|
|
|
|
#[instrument(skip(self))]
|
|
|
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
|
|
|
.map(|client| client.health())
|
|
|
|
.collect();
|
|
|
|
join_all(futures).await.pop().unwrap()
|
|
|
|
}
|
|
|
|
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Clear the past generations cache
|
2023-02-13 12:02:45 +00:00
|
|
|
#[instrument(skip(self))]
|
2023-03-28 09:29:35 +00:00
|
|
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
2023-01-31 16:04:00 +00:00
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
2023-03-28 09:29:35 +00:00
|
|
|
.map(|client| client.clear_cache(batch_id))
|
2023-01-31 16:04:00 +00:00
|
|
|
.collect();
|
|
|
|
join_all(futures).await.into_iter().collect()
|
|
|
|
}
|
|
|
|
|
2023-04-24 15:59:00 +00:00
|
|
|
/// Filter a cached batch
|
|
|
|
#[instrument(skip(self))]
|
|
|
|
pub async fn filter_batch(
|
|
|
|
&mut self,
|
|
|
|
batch_id: u64,
|
2023-05-24 17:19:57 +00:00
|
|
|
request_ids: Vec<u64>,
|
|
|
|
) -> Result<Option<CachedBatch>> {
|
2023-04-24 15:59:00 +00:00
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
2023-05-24 17:19:57 +00:00
|
|
|
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
2023-04-24 15:59:00 +00:00
|
|
|
.collect();
|
|
|
|
// all shards return the same message
|
|
|
|
join_all(futures).await.pop().unwrap()
|
|
|
|
}
|
|
|
|
|
2023-06-30 17:09:59 +00:00
|
|
|
/// Warmup on a max size batch
|
|
|
|
///
|
|
|
|
/// Returns the maximum amount of tokens supported by the hardware
|
|
|
|
#[instrument(skip(self))]
|
|
|
|
pub async fn warmup(
|
|
|
|
&mut self,
|
|
|
|
max_input_length: u32,
|
|
|
|
max_prefill_tokens: u32,
|
2023-10-20 08:28:45 +00:00
|
|
|
max_total_tokens: u32,
|
2024-02-09 11:38:41 +00:00
|
|
|
max_batch_size: Option<usize>,
|
2023-07-19 07:31:25 +00:00
|
|
|
) -> Result<Option<u32>> {
|
2023-06-30 17:09:59 +00:00
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
2023-10-20 08:28:45 +00:00
|
|
|
.map(|client| {
|
2024-02-09 11:38:41 +00:00
|
|
|
Box::pin(client.warmup(
|
|
|
|
max_input_length,
|
|
|
|
max_prefill_tokens,
|
|
|
|
max_total_tokens,
|
|
|
|
max_batch_size,
|
|
|
|
))
|
2023-10-20 08:28:45 +00:00
|
|
|
})
|
2023-06-30 17:09:59 +00:00
|
|
|
.collect();
|
2023-07-24 09:43:58 +00:00
|
|
|
// Take the minimum value
|
|
|
|
let results = join_all(futures)
|
|
|
|
.await
|
|
|
|
.into_iter()
|
|
|
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
|
|
|
Ok(results.into_iter().flatten().min())
|
2023-06-30 17:09:59 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Generate one token for each request in the given batch
|
|
|
|
///
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Returns Generation for each request in batch
|
2022-10-18 13:19:03 +00:00
|
|
|
/// and the next cached batch
|
2023-12-14 14:59:38 +00:00
|
|
|
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
2023-05-24 17:19:57 +00:00
|
|
|
pub async fn prefill(
|
|
|
|
&mut self,
|
|
|
|
batch: Batch,
|
2024-10-16 10:49:33 +00:00
|
|
|
cached_batch: Option<CachedBatch>,
|
2023-12-14 14:59:38 +00:00
|
|
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
2022-10-22 21:40:05 +00:00
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
2024-10-16 10:49:33 +00:00
|
|
|
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
2022-10-22 21:40:05 +00:00
|
|
|
.collect();
|
2024-01-22 14:22:54 +00:00
|
|
|
#[allow(clippy::type_complexity)]
|
2023-12-14 14:59:38 +00:00
|
|
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
2023-05-10 13:48:21 +00:00
|
|
|
join_all(futures).await.into_iter().collect();
|
2023-12-14 14:59:38 +00:00
|
|
|
let mut results = results?;
|
|
|
|
|
|
|
|
let (mut generations, next_batch, mut timings) =
|
|
|
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
|
|
|
|
|
|
|
// Merge generations from different model shards
|
|
|
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
|
|
|
generations.append(&mut shard_generations);
|
|
|
|
// Return the timings of the slowest shard
|
|
|
|
if shard_timings.total > timings.total {
|
|
|
|
timings = shard_timings;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok((generations, next_batch, timings))
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Generate one token for each request in the given cached batches
|
2022-10-18 13:19:03 +00:00
|
|
|
///
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Returns Generation for each request in batches
|
2022-10-18 13:19:03 +00:00
|
|
|
/// and the next cached batch
|
2023-12-14 14:59:38 +00:00
|
|
|
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
2023-01-31 16:04:00 +00:00
|
|
|
pub async fn decode(
|
2022-10-22 21:40:05 +00:00
|
|
|
&mut self,
|
2023-05-24 17:19:57 +00:00
|
|
|
batches: Vec<CachedBatch>,
|
2023-12-14 14:59:38 +00:00
|
|
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
2022-10-22 21:40:05 +00:00
|
|
|
let futures: Vec<_> = self
|
|
|
|
.clients
|
|
|
|
.iter_mut()
|
2023-01-31 16:04:00 +00:00
|
|
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
2022-10-22 21:40:05 +00:00
|
|
|
.collect();
|
2024-01-22 14:22:54 +00:00
|
|
|
#[allow(clippy::type_complexity)]
|
2023-12-14 14:59:38 +00:00
|
|
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
2023-05-10 13:48:21 +00:00
|
|
|
join_all(futures).await.into_iter().collect();
|
2023-12-14 14:59:38 +00:00
|
|
|
let mut results = results?;
|
2023-05-10 13:48:21 +00:00
|
|
|
|
2023-12-14 14:59:38 +00:00
|
|
|
let (mut generations, next_batch, mut timings) =
|
|
|
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
2023-05-10 13:48:21 +00:00
|
|
|
|
2023-12-14 14:59:38 +00:00
|
|
|
// Merge generations from different model shards
|
|
|
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
|
|
|
generations.append(&mut shard_generations);
|
|
|
|
// Return the timings of the slowest shard
|
|
|
|
if shard_timings.total > timings.total {
|
|
|
|
timings = shard_timings;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok((generations, next_batch, timings))
|
2022-10-11 14:50:54 +00:00
|
|
|
}
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
2024-06-04 13:56:56 +00:00
|
|
|
|
|
|
|
impl From<InfoResponse> for ShardInfo {
|
|
|
|
fn from(value: InfoResponse) -> Self {
|
|
|
|
Self {
|
|
|
|
requires_padding: value.requires_padding,
|
|
|
|
dtype: value.dtype,
|
|
|
|
device_type: value.device_type,
|
|
|
|
window_size: value.window_size,
|
|
|
|
speculate: value.speculate,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
impl Health for ShardedClient {
|
|
|
|
async fn device_health(&self) -> Result<()> {
|
|
|
|
self.clone().health().await?;
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn model_health(&self) -> Result<()> {
|
|
|
|
// Dummy batch of 1 token and 1 generated token
|
|
|
|
let liveness_request = Request {
|
|
|
|
id: u64::MAX,
|
|
|
|
inputs: "liveness".to_string(),
|
|
|
|
input_chunks: Some(Input {
|
|
|
|
chunks: vec![Chunk::Text("liveness".into()).into()],
|
|
|
|
}),
|
|
|
|
truncate: 10,
|
2024-08-29 14:29:01 +00:00
|
|
|
add_special_tokens: true,
|
2024-06-04 13:56:56 +00:00
|
|
|
prefill_logprobs: false,
|
|
|
|
parameters: Some(NextTokenChooserParameters {
|
|
|
|
temperature: 1.0,
|
|
|
|
top_k: 0,
|
|
|
|
top_p: 1.0,
|
|
|
|
typical_p: 1.0,
|
|
|
|
do_sample: false,
|
|
|
|
seed: 0,
|
|
|
|
repetition_penalty: 1.0,
|
|
|
|
frequency_penalty: 0.0,
|
|
|
|
watermark: false,
|
|
|
|
grammar: String::new(),
|
|
|
|
grammar_type: GrammarType::None as i32,
|
|
|
|
}),
|
|
|
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
|
|
|
max_new_tokens: 1,
|
|
|
|
stop_sequences: vec![],
|
|
|
|
ignore_eos_token: false,
|
|
|
|
}),
|
|
|
|
top_n_tokens: 0,
|
2024-06-05 10:18:38 +00:00
|
|
|
// Block 0 is reserved for health checks
|
|
|
|
blocks: vec![0],
|
|
|
|
slots: (0..16).collect(),
|
2024-10-16 10:49:33 +00:00
|
|
|
cache_len: 0,
|
|
|
|
chunk_len: None,
|
2024-06-25 18:46:27 +00:00
|
|
|
adapter_id: None,
|
2024-06-04 13:56:56 +00:00
|
|
|
};
|
|
|
|
let batch = Batch {
|
|
|
|
id: u64::MAX,
|
|
|
|
requests: vec![liveness_request],
|
|
|
|
size: 1,
|
|
|
|
max_tokens: 2,
|
2024-06-05 10:18:38 +00:00
|
|
|
max_blocks: 1,
|
2024-06-04 13:56:56 +00:00
|
|
|
};
|
2024-10-16 10:49:33 +00:00
|
|
|
self.clone().prefill(batch, None).await?;
|
2024-06-04 13:56:56 +00:00
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|