mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
added logit_bias
to benchmarks.
This commit is contained in:
parent
a06b681673
commit
9b3be8f79b
@ -8,7 +8,7 @@ use crate::app::App;
|
||||
use crate::event::Event;
|
||||
use crossterm::ExecutableCommand;
|
||||
use std::io;
|
||||
use text_generation_client::{NextTokenChooserParameters, ShardedClient};
|
||||
use text_generation_client::{LogitBias, NextTokenChooserParameters, ShardedClient};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tui::backend::CrosstermBackend;
|
||||
@ -31,6 +31,7 @@ pub async fn run(
|
||||
repetition_penalty: Option<f32>,
|
||||
watermark: bool,
|
||||
do_sample: bool,
|
||||
logit_bias: Vec<(String, f32)>,
|
||||
client: ShardedClient,
|
||||
) -> Result<(), crossterm::ErrorKind> {
|
||||
let parameters = NextTokenChooserParameters {
|
||||
@ -42,7 +43,10 @@ pub async fn run(
|
||||
seed: 0,
|
||||
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
||||
watermark,
|
||||
logit_bias: vec![],
|
||||
logit_bias: logit_bias
|
||||
.iter()
|
||||
.map(|(string, bias)| LogitBias { string: string.clone(), bias: *bias })
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
@ -140,6 +144,7 @@ pub async fn run(
|
||||
repetition_penalty,
|
||||
watermark,
|
||||
do_sample,
|
||||
logit_bias,
|
||||
);
|
||||
println!("\n{parameters_table}\n");
|
||||
|
||||
|
@ -93,6 +93,11 @@ struct Args {
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env)]
|
||||
do_sample: bool,
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(default_values_t="vec![]", long, env)]
|
||||
logit_bias: Vec<(String, f32)>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
watermark,
|
||||
do_sample,
|
||||
master_shard_uds_path,
|
||||
logit_bias,
|
||||
} = args;
|
||||
|
||||
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
||||
@ -182,6 +188,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
repetition_penalty,
|
||||
watermark,
|
||||
do_sample,
|
||||
logit_bias,
|
||||
sharded_client,
|
||||
)
|
||||
.await
|
||||
|
@ -16,6 +16,7 @@ pub(crate) fn parameters_table(
|
||||
repetition_penalty: Option<f32>,
|
||||
watermark: bool,
|
||||
do_sample: bool,
|
||||
logit_bias: Vec<(String, f32)>
|
||||
) -> Table {
|
||||
let mut builder = Builder::default();
|
||||
|
||||
@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
|
||||
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
|
||||
builder.push_record(["Watermark", &watermark.to_string()]);
|
||||
builder.push_record(["Do Sample", &do_sample.to_string()]);
|
||||
builder.push_record(["Logit Bias", &format!("{logit_bias:?}")]);
|
||||
|
||||
let mut table = builder.build();
|
||||
table.with(Style::markdown());
|
||||
|
Loading…
Reference in New Issue
Block a user