This commit is contained in:
Yannic Kilcher 2023-01-26 14:57:39 +01:00
parent 7beb968696
commit 65efd51233
4 changed files with 25 additions and 8 deletions

View File

@ -73,7 +73,10 @@ impl Client {
/// Returns a list of generated texts of request that met their stopping criteria /// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch /// and the next cached batch
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> { pub async fn generate(
&mut self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
let response = self let response = self
.stub .stub
@ -81,7 +84,11 @@ impl Client {
.instrument(info_span!("generate")) .instrument(info_span!("generate"))
.await? .await?
.into_inner(); .into_inner();
Ok((response.generated_texts, response.batch, response.intermediates)) Ok((
response.generated_texts,
response.batch,
response.intermediates,
))
} }
/// Generate one token for each request in the given cached batch /// Generate one token for each request in the given cached batch
@ -100,6 +107,10 @@ impl Client {
.instrument(info_span!("generate_with_cache")) .instrument(info_span!("generate_with_cache"))
.await? .await?
.into_inner(); .into_inner();
Ok((response.generated_texts, response.batch, response.intermediates)) Ok((
response.generated_texts,
response.batch,
response.intermediates,
))
} }
} }

View File

@ -7,7 +7,8 @@ mod sharded_client;
pub use client::Client; pub use client::Client;
pub use pb::generate::v1::{ pub use pb::generate::v1::{
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate, Batch, GeneratedText, Intermediate, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;
@ -41,4 +42,4 @@ pub struct IntermediateEvent {
pub token: String, pub token: String,
pub token_id: u32, pub token_id: u32,
pub logprob: f32, pub logprob: f32,
} }

View File

@ -41,7 +41,10 @@ impl ShardedClient {
/// ///
/// Returns a list of generated texts of request that met their stopping criteria /// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch /// and the next cached batch
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> { pub async fn generate(
&mut self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
.iter_mut() .iter_mut()

View File

@ -5,7 +5,8 @@ use parking_lot::Mutex;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate, Batch, ClientError, Intermediate, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
}; };
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tokio::time::Instant; use tokio::time::Instant;
@ -18,7 +19,8 @@ pub(crate) struct Entry {
/// Response sender to communicate between the Batcher and the batching_task /// Response sender to communicate between the Batcher and the batching_task
pub response_tx: Sender<Result<InferResponse, ClientError>>, pub response_tx: Sender<Result<InferResponse, ClientError>>,
/// Intermediate sender to communicate between the Batcher and the batching_task /// Intermediate sender to communicate between the Batcher and the batching_task
pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>, pub intermediate_tx:
Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
/// Number of tokens in the input /// Number of tokens in the input
pub input_length: usize, pub input_length: usize,
/// Instant when this entry was created /// Instant when this entry was created