added logit_bias to benchmarks.

This commit is contained in:
marcusdunn 2023-08-15 13:48:39 -07:00
parent a06b681673
commit 9b3be8f79b
3 changed files with 16 additions and 2 deletions

View File

@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::{NextTokenChooserParameters, ShardedClient};
use text_generation_client::{LogitBias, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
@ -31,6 +31,7 @@ pub async fn run(
repetition_penalty: Option<f32>,
watermark: bool,
do_sample: bool,
logit_bias: Vec<(String, f32)>,
client: ShardedClient,
) -> Result<(), crossterm::ErrorKind> {
let parameters = NextTokenChooserParameters {
@ -42,7 +43,10 @@ pub async fn run(
seed: 0,
repetition_penalty: repetition_penalty.unwrap_or(1.0),
watermark,
logit_bias: vec![],
logit_bias: logit_bias
.iter()
.map(|(string, bias)| LogitBias { string: string.clone(), bias: *bias })
.collect()
};
// Initialize terminal properties
@ -140,6 +144,7 @@ pub async fn run(
repetition_penalty,
watermark,
do_sample,
logit_bias,
);
println!("\n{parameters_table}\n");

View File

@ -93,6 +93,11 @@ struct Args {
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
do_sample: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(default_values_t="vec![]", long, env)]
logit_bias: Vec<(String, f32)>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
watermark,
do_sample,
master_shard_uds_path,
logit_bias,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@ -182,6 +188,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
repetition_penalty,
watermark,
do_sample,
logit_bias,
sharded_client,
)
.await

View File

@ -16,6 +16,7 @@ pub(crate) fn parameters_table(
repetition_penalty: Option<f32>,
watermark: bool,
do_sample: bool,
logit_bias: Vec<(String, f32)>
) -> Table {
let mut builder = Builder::default();
@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
builder.push_record(["Watermark", &watermark.to_string()]);
builder.push_record(["Do Sample", &do_sample.to_string()]);
builder.push_record(["Logit Bias", &format!("{logit_bias:?}")]);
let mut table = builder.build();
table.with(Style::markdown());