This commit is contained in:
Yannic Kilcher 2023-01-26 14:57:39 +01:00
parent 7beb968696
commit 65efd51233
4 changed files with 25 additions and 8 deletions

View File

@ -73,7 +73,10 @@ impl Client {
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
#[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
pub async fn generate(
&mut self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
let response = self
.stub
@ -81,7 +84,11 @@ impl Client {
.instrument(info_span!("generate"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch, response.intermediates))
Ok((
response.generated_texts,
response.batch,
response.intermediates,
))
}
/// Generate one token for each request in the given cached batch
@ -100,6 +107,10 @@ impl Client {
.instrument(info_span!("generate_with_cache"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch, response.intermediates))
Ok((
response.generated_texts,
response.batch,
response.intermediates,
))
}
}

View File

@ -7,7 +7,8 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
Batch, GeneratedText, Intermediate, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -41,7 +41,10 @@ impl ShardedClient {
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
pub async fn generate(
&mut self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let futures: Vec<_> = self
.clients
.iter_mut()

View File

@ -5,7 +5,8 @@ use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
Batch, ClientError, Intermediate, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
};
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
@ -18,7 +19,8 @@ pub(crate) struct Entry {
/// Response sender to communicate between the Batcher and the batching_task
pub response_tx: Sender<Result<InferResponse, ClientError>>,
/// Intermediate sender to communicate between the Batcher and the batching_task
pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
pub intermediate_tx:
Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created