mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
added logit_bias
to benchmarks.
This commit is contained in:
parent
a06b681673
commit
9b3be8f79b
@ -8,7 +8,7 @@ use crate::app::App;
|
|||||||
use crate::event::Event;
|
use crate::event::Event;
|
||||||
use crossterm::ExecutableCommand;
|
use crossterm::ExecutableCommand;
|
||||||
use std::io;
|
use std::io;
|
||||||
use text_generation_client::{NextTokenChooserParameters, ShardedClient};
|
use text_generation_client::{LogitBias, NextTokenChooserParameters, ShardedClient};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
use tui::backend::CrosstermBackend;
|
use tui::backend::CrosstermBackend;
|
||||||
@ -31,6 +31,7 @@ pub async fn run(
|
|||||||
repetition_penalty: Option<f32>,
|
repetition_penalty: Option<f32>,
|
||||||
watermark: bool,
|
watermark: bool,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
|
logit_bias: Vec<(String, f32)>,
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
) -> Result<(), crossterm::ErrorKind> {
|
) -> Result<(), crossterm::ErrorKind> {
|
||||||
let parameters = NextTokenChooserParameters {
|
let parameters = NextTokenChooserParameters {
|
||||||
@ -42,7 +43,10 @@ pub async fn run(
|
|||||||
seed: 0,
|
seed: 0,
|
||||||
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
||||||
watermark,
|
watermark,
|
||||||
logit_bias: vec![],
|
logit_bias: logit_bias
|
||||||
|
.iter()
|
||||||
|
.map(|(string, bias)| LogitBias { string: string.clone(), bias: *bias })
|
||||||
|
.collect()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize terminal properties
|
// Initialize terminal properties
|
||||||
@ -140,6 +144,7 @@ pub async fn run(
|
|||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
watermark,
|
watermark,
|
||||||
do_sample,
|
do_sample,
|
||||||
|
logit_bias,
|
||||||
);
|
);
|
||||||
println!("\n{parameters_table}\n");
|
println!("\n{parameters_table}\n");
|
||||||
|
|
||||||
|
@ -93,6 +93,11 @@ struct Args {
|
|||||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
|
|
||||||
|
/// Logit bias entries, each pairing a string with an `f32` bias value.
|
||||||
|
/// Generation parameter in case you want to specifically test/debug particular decoding strategies, for full doc refer to the `text-generation-server`
|
||||||
|
#[clap(default_values_t="vec![]", long, env)]
|
||||||
|
logit_bias: Vec<(String, f32)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
watermark,
|
watermark,
|
||||||
do_sample,
|
do_sample,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
|
logit_bias,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
||||||
@ -182,6 +188,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
watermark,
|
watermark,
|
||||||
do_sample,
|
do_sample,
|
||||||
|
logit_bias,
|
||||||
sharded_client,
|
sharded_client,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
@ -16,6 +16,7 @@ pub(crate) fn parameters_table(
|
|||||||
repetition_penalty: Option<f32>,
|
repetition_penalty: Option<f32>,
|
||||||
watermark: bool,
|
watermark: bool,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
|
logit_bias: Vec<(String, f32)>
|
||||||
) -> Table {
|
) -> Table {
|
||||||
let mut builder = Builder::default();
|
let mut builder = Builder::default();
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
|
|||||||
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
|
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
|
||||||
builder.push_record(["Watermark", &watermark.to_string()]);
|
builder.push_record(["Watermark", &watermark.to_string()]);
|
||||||
builder.push_record(["Do Sample", &do_sample.to_string()]);
|
builder.push_record(["Do Sample", &do_sample.to_string()]);
|
||||||
|
builder.push_record(["Logit Bias", &format!("{logit_bias:?}")]);
|
||||||
|
|
||||||
let mut table = builder.build();
|
let mut table = builder.build();
|
||||||
table.with(Style::markdown());
|
table.with(Style::markdown());
|
||||||
|
Loading…
Reference in New Issue
Block a user