// text-generation-inference/router/client/src/client.rs
//! Single shard gRPC client.
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use tonic::transport::{Channel, Uri};
use tracing::*;
/// Text Generation Inference gRPC client
2022-10-08 10:30:12 +00:00
#[derive(Clone)]
pub struct Client {
2022-10-17 12:59:00 +00:00
stub: TextGenerationServiceClient<Channel>,
2022-10-08 10:30:12 +00:00
}
impl Client {
2022-10-17 12:59:00 +00:00
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
2022-10-08 10:30:12 +00:00
2022-10-17 12:59:00 +00:00
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
2022-10-08 10:30:12 +00:00
}
2022-10-17 12:59:00 +00:00
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
2022-10-08 10:30:12 +00:00
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
2022-10-17 12:59:00 +00:00
.await?;
2022-10-08 10:30:12 +00:00
2022-10-17 12:59:00 +00:00
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
2022-10-08 10:30:12 +00:00
}
2022-10-18 13:19:03 +00:00
/// Returns a list of uris or unix sockets of all shards
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {});
2022-10-08 10:30:12 +00:00
let response = self
.stub
.service_discovery(request)
.instrument(info_span!("service_discovery"))
.await?;
let urls = response
.into_inner()
.urls
.into_iter()
2022-10-18 13:19:03 +00:00
// Remove unix socket prefix
2022-10-08 10:30:12 +00:00
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
2022-10-18 13:19:03 +00:00
/// Clear the past generations cache
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {});
2022-10-08 10:30:12 +00:00
self.stub
.clear_cache(request)
.instrument(info_span!("clear_cache"))
.await?;
Ok(())
}
2022-10-18 13:19:03 +00:00
/// Generate one token for each request in the given batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
2022-10-08 10:30:12 +00:00
let response = self
.stub
.generate(request)
.instrument(info_span!("generate"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
2022-10-08 10:30:12 +00:00
}
2022-10-18 13:19:03 +00:00
/// Generate one token for each request in the given cached batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn generate_with_cache(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateWithCacheRequest { batches });
2022-10-08 10:30:12 +00:00
let response = self
.stub
.generate_with_cache(request)
.instrument(info_span!("generate_with_cache"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
}
2022-10-08 10:30:12 +00:00
}