2022-10-18 13:19:03 +00:00
|
|
|
/// Single shard Client
|
2022-10-11 14:50:54 +00:00
|
|
|
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
|
2022-10-08 10:30:12 +00:00
|
|
|
use crate::pb::generate::v1::*;
|
|
|
|
use crate::Result;
|
|
|
|
use tonic::transport::{Channel, Uri};
|
|
|
|
use tracing::*;
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Text Generation Inference gRPC client
|
2022-10-08 10:30:12 +00:00
|
|
|
#[derive(Clone)]
|
|
|
|
pub struct Client {
|
2022-10-17 12:59:00 +00:00
|
|
|
stub: TextGenerationServiceClient<Channel>,
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Client {
|
2022-10-17 12:59:00 +00:00
|
|
|
/// Returns a client connected to the given url
|
|
|
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
|
|
|
let channel = Channel::builder(uri).connect().await?;
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-17 12:59:00 +00:00
|
|
|
Ok(Self {
|
|
|
|
stub: TextGenerationServiceClient::new(channel),
|
|
|
|
})
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2022-10-17 12:59:00 +00:00
|
|
|
/// Returns a client connected to the given unix socket
|
|
|
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
2022-10-11 14:50:54 +00:00
|
|
|
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
2022-10-08 10:30:12 +00:00
|
|
|
.unwrap()
|
|
|
|
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
|
|
|
tokio::net::UnixStream::connect(path.clone())
|
|
|
|
}))
|
2022-10-17 12:59:00 +00:00
|
|
|
.await?;
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-17 12:59:00 +00:00
|
|
|
Ok(Self {
|
|
|
|
stub: TextGenerationServiceClient::new(channel),
|
|
|
|
})
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Returns a list of uris or unix sockets of all shards
|
2022-10-08 10:30:12 +00:00
|
|
|
#[instrument(skip(self))]
|
|
|
|
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
2022-10-11 14:50:54 +00:00
|
|
|
let request = tonic::Request::new(ServiceDiscoveryRequest {});
|
2022-10-08 10:30:12 +00:00
|
|
|
let response = self
|
|
|
|
.stub
|
|
|
|
.service_discovery(request)
|
|
|
|
.instrument(info_span!("service_discovery"))
|
|
|
|
.await?;
|
|
|
|
let urls = response
|
|
|
|
.into_inner()
|
|
|
|
.urls
|
|
|
|
.into_iter()
|
2022-10-18 13:19:03 +00:00
|
|
|
// Remove unix socket prefix
|
2022-10-08 10:30:12 +00:00
|
|
|
.map(|url| match url.strip_prefix("unix://") {
|
|
|
|
None => url,
|
|
|
|
Some(stripped_url) => stripped_url.to_string(),
|
|
|
|
})
|
|
|
|
.collect();
|
|
|
|
Ok(urls)
|
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Clear the past generations cache
|
2022-10-08 10:30:12 +00:00
|
|
|
#[instrument(skip(self))]
|
|
|
|
pub async fn clear_cache(&mut self) -> Result<()> {
|
2022-10-11 14:50:54 +00:00
|
|
|
let request = tonic::Request::new(ClearCacheRequest {});
|
2022-10-08 10:30:12 +00:00
|
|
|
self.stub
|
|
|
|
.clear_cache(request)
|
|
|
|
.instrument(info_span!("clear_cache"))
|
|
|
|
.await?;
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Generate one token for each request in the given batch
|
|
|
|
///
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Returns Generation for each request in batch
|
2022-10-18 13:19:03 +00:00
|
|
|
/// and the next cached batch
|
2022-10-08 10:30:12 +00:00
|
|
|
#[instrument(skip(self))]
|
2023-01-31 16:04:00 +00:00
|
|
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
|
|
|
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
|
2022-10-08 10:30:12 +00:00
|
|
|
let response = self
|
|
|
|
.stub
|
2023-01-31 16:04:00 +00:00
|
|
|
.prefill(request)
|
|
|
|
.instrument(info_span!("prefill"))
|
2022-10-08 10:30:12 +00:00
|
|
|
.await?
|
|
|
|
.into_inner();
|
2023-01-31 16:04:00 +00:00
|
|
|
Ok((response.generations, response.batch))
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Generate one token for each request in the given cached batches
|
2022-10-18 13:19:03 +00:00
|
|
|
///
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Returns Generation for each request in batches
|
2022-10-18 13:19:03 +00:00
|
|
|
/// and the next cached batch
|
2022-10-08 10:30:12 +00:00
|
|
|
#[instrument(skip(self))]
|
2023-01-31 16:04:00 +00:00
|
|
|
pub async fn decode(
|
2022-10-08 10:30:12 +00:00
|
|
|
&mut self,
|
2022-10-11 14:50:54 +00:00
|
|
|
batches: Vec<Batch>,
|
2023-01-31 16:04:00 +00:00
|
|
|
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
|
|
|
let request = tonic::Request::new(DecodeRequest { batches });
|
2022-10-08 10:30:12 +00:00
|
|
|
let response = self
|
|
|
|
.stub
|
2023-01-31 16:04:00 +00:00
|
|
|
.decode(request)
|
|
|
|
.instrument(info_span!("decode"))
|
2022-10-08 10:30:12 +00:00
|
|
|
.await?
|
|
|
|
.into_inner();
|
2023-01-31 16:04:00 +00:00
|
|
|
Ok((response.generations, response.batch))
|
2022-10-11 14:50:54 +00:00
|
|
|
}
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|