// text-generation-inference/router/client/src/client.rs
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
2022-10-18 13:19:03 +00:00
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
2022-10-08 10:30:12 +00:00
pub struct Client {
2022-10-17 12:59:00 +00:00
stub: TextGenerationServiceClient<Channel>,
2022-10-08 10:30:12 +00:00
}
impl Client {
2022-10-17 12:59:00 +00:00
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
2022-10-08 10:30:12 +00:00
2022-10-17 12:59:00 +00:00
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
2022-10-08 10:30:12 +00:00
}
2022-10-17 12:59:00 +00:00
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
2022-10-08 10:30:12 +00:00
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
2022-10-17 12:59:00 +00:00
.await?;
2022-10-08 10:30:12 +00:00
2022-10-17 12:59:00 +00:00
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
2022-10-08 10:30:12 +00:00
}
2022-10-18 13:19:03 +00:00
/// Returns a list of uris or unix sockets of all shards
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
2023-02-13 12:02:45 +00:00
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
2022-10-08 10:30:12 +00:00
let urls = response
.into_inner()
.urls
.into_iter()
2022-10-18 13:19:03 +00:00
// Remove unix socket prefix
2022-10-08 10:30:12 +00:00
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
2022-10-18 13:19:03 +00:00
/// Clear the past generations cache
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
2023-02-13 12:02:45 +00:00
self.stub.clear_cache(request).await?;
2022-10-08 10:30:12 +00:00
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
stop_sequences: vec![],
ignore_eos_token: false,
}),
prefill_logprobs: true,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
2022-10-18 13:19:03 +00:00
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
2022-10-18 13:19:03 +00:00
/// and the next cached batch
2023-02-13 12:02:45 +00:00
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
2023-02-13 12:02:45 +00:00
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
2022-10-08 10:30:12 +00:00
}
/// Generate one token for each request in the given cached batches
2022-10-18 13:19:03 +00:00
///
/// Returns Generation for each request in batches
2022-10-18 13:19:03 +00:00
/// and the next cached batch
2023-02-13 12:02:45 +00:00
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
2022-10-08 10:30:12 +00:00
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
2023-02-13 12:02:45 +00:00
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
}
2022-10-08 10:30:12 +00:00
}