From 65efd512339bf7c5c32b036218ad2ac0578fc6ca Mon Sep 17 00:00:00 2001
From: Yannic Kilcher
Date: Thu, 26 Jan 2023 14:57:39 +0100
Subject: [PATCH] cleanup

---
 router/client/src/client.rs         | 17 ++++++++++++++---
 router/client/src/lib.rs            |  5 +++--
 router/client/src/sharded_client.rs |  5 ++++-
 router/src/db.rs                    |  6 ++++--
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 47876298..ab5858a7 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -73,7 +73,10 @@ impl Client {
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     #[instrument(skip(self))]
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
+    pub async fn generate(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
         let response = self
             .stub
@@ -81,7 +84,11 @@ impl Client {
             .instrument(info_span!("generate"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch, response.intermediates))
+        Ok((
+            response.generated_texts,
+            response.batch,
+            response.intermediates,
+        ))
     }
 
     /// Generate one token for each request in the given cached batch
@@ -100,6 +107,10 @@ impl Client {
             .instrument(info_span!("generate_with_cache"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch, response.intermediates))
+        Ok((
+            response.generated_texts,
+            response.batch,
+            response.intermediates,
+        ))
     }
 }
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 13cc196b..22e77ee7 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -7,7 +7,8 @@ mod sharded_client;
 
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
+    Batch, GeneratedText, Intermediate, NextTokenChooserParameters, Request,
+    StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -41,4 +42,4 @@ pub struct IntermediateEvent {
     pub token: String,
     pub token_id: u32,
     pub logprob: f32,
-}
\ No newline at end of file
+}
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index e8a939e8..2c370a51 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -41,7 +41,10 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
+    pub async fn generate(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
diff --git a/router/src/db.rs b/router/src/db.rs
index 1c27fc4d..2576e466 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -5,7 +5,8 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
+    Batch, ClientError, Intermediate, NextTokenChooserParameters, Request,
+    StoppingCriteriaParameters,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -18,7 +19,8 @@ pub(crate) struct Entry {
     /// Response sender to communicate between the Batcher and the batching_task
     pub response_tx: Sender>,
     /// Intermediate sender to communicate between the Batcher and the batching_task
-    pub intermediate_tx: Option, ClientError>>>,
+    pub intermediate_tx:
+        Option, ClientError>>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
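
For context on the API touched by this cleanup: `Client::generate` and `ShardedClient::generate` return a three-element tuple of finished texts, the optional next cached batch, and the per-token intermediates. The sketch below is illustrative only and not part of the patch; the function name `decode_step` and the way the pieces are handed back are assumptions, while the imports mirror the re-exports visible in lib.rs and db.rs.

    use text_generation_client::{Batch, ClientError, GeneratedText, Intermediate, ShardedClient};

    // Illustrative sketch, not part of the patch: a single call against the
    // reformatted `generate` signature, spelling out what each element of the
    // widened result tuple is for.
    async fn decode_step(
        client: &mut ShardedClient,
        batch: Batch,
    ) -> Result<(Vec<GeneratedText>, Vec<Intermediate>, Option<Batch>), ClientError> {
        // Finished texts belong to requests that met a stopping criterion; the next
        // cached batch (if any) feeds the following decoding step; the intermediates
        // carry per-token progress, which is what the new `Entry::intermediate_tx`
        // field in db.rs is meant to forward.
        let (generated_texts, next_batch, intermediates) = client.generate(batch).await?;
        Ok((generated_texts, intermediates, next_batch))
    }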