diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index d6433612..014fb20e 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -8,7 +8,7 @@ use crate::app::App; use crate::event::Event; use crossterm::ExecutableCommand; use std::io; -use text_generation_client::{NextTokenChooserParameters, ShardedClient}; +use text_generation_client::{LogitBias, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; @@ -31,6 +31,7 @@ pub async fn run( repetition_penalty: Option, watermark: bool, do_sample: bool, + logit_bias: Vec<(String, f32)>, client: ShardedClient, ) -> Result<(), crossterm::ErrorKind> { let parameters = NextTokenChooserParameters { @@ -42,7 +43,10 @@ pub async fn run( seed: 0, repetition_penalty: repetition_penalty.unwrap_or(1.0), watermark, - logit_bias: vec![], + logit_bias: logit_bias + .iter() + .map(|(string, bias)| LogitBias { string: string.clone(), bias: *bias }) + .collect() }; // Initialize terminal properties @@ -140,6 +144,7 @@ pub async fn run( repetition_penalty, watermark, do_sample, + logit_bias, ); println!("\n{parameters_table}\n"); diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index a7550060..f32b1ad4 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -93,6 +93,11 @@ struct Args { /// decoding strategies, for full doc refer to the `text-generation-server` #[clap(long, env)] do_sample: bool, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(default_values_t="vec![]", long, env)] + logit_bias: Vec<(String, f32)>, } fn main() -> Result<(), Box> { @@ -117,6 +122,7 @@ fn main() -> Result<(), Box> { watermark, do_sample, master_shard_uds_path, + logit_bias, } = args; let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); @@ -182,6 +188,7 @@ fn main() -> Result<(), Box> { repetition_penalty, watermark, do_sample, + logit_bias, sharded_client, ) .await diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index 6b74bc36..999b1859 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -16,6 +16,7 @@ pub(crate) fn parameters_table( repetition_penalty: Option, watermark: bool, do_sample: bool, + logit_bias: Vec<(String, f32)> ) -> Table { let mut builder = Builder::default(); @@ -33,6 +34,7 @@ pub(crate) fn parameters_table( builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]); builder.push_record(["Watermark", &watermark.to_string()]); builder.push_record(["Do Sample", &do_sample.to_string()]); + builder.push_record(["Logit Bias", &format!("{logit_bias:?}")]); let mut table = builder.build(); table.with(Style::markdown());