Add top-n-tokens support to benchmark

Vincent Brouwers 2023-07-24 14:02:56 +00:00 committed by Nicolas Patry
parent 8a4d2076a6
commit 0facd94738
4 changed files with 17 additions and 6 deletions

--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs

@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
     // End task if a message is received on shutdown_receiver
     // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
-        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
             if let Err(err) = res {
                 run_sender.send(Err(err)).await.unwrap_or(());
             }
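
The `tokio::select!` in the hunk above races the benchmark run against the shutdown channel: whichever future completes first wins, and the loser is dropped, which cancels it. A minimal standalone sketch of that pattern (the names `work`, `shutdown_sender`, and `shutdown_receiver` are illustrative; assumes tokio with the `full` feature):

    // Standalone sketch of select!-based cancellation as used by generation_task.
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (shutdown_sender, mut shutdown_receiver) = mpsc::channel::<()>(1);

        // Stand-in for the long-running benchmark future.
        let work = tokio::time::sleep(std::time::Duration::from_secs(60));

        shutdown_sender.send(()).await.unwrap(); // request shutdown immediately

        tokio::select! {
            _ = work => println!("benchmark finished"),
            _ = shutdown_receiver.recv() => println!("shutdown requested"),
        }
    }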
@@ -64,6 +65,7 @@ async fn generate_runs(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -73,9 +75,6 @@ async fn generate_runs(
     // Create a dummy sequence
     let sequence = create_sequence(sequence_length, tokenizer);
-    // TODO: Implement top_n_tokens
-    let top_n_tokens = 0;
-
     for b in batch_size {
         // Warmups on batch size
         for _ in 0..warmups {
@@ -135,7 +134,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
-    top_n_tokens: u32,
+    top_n_tokens: Option<u32>,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -151,7 +150,7 @@ async fn prefill(
                 stop_sequences: vec![],
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
-            top_n_tokens: top_n_tokens,
+            top_n_tokens: top_n_tokens.unwrap_or(0),
         })
         .collect();
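
The last hunk above is where the optional flag meets the wire format: the request's `top_n_tokens` field is a plain `u32`, so `None` is encoded as `0`, the value the server side interprets as "top-n tokens disabled". A quick standalone sketch of that mapping (not the benchmark's actual types):

    // Option<u32> -> u32 at the request boundary: 0 doubles as "disabled".
    fn encode_top_n_tokens(top_n_tokens: Option<u32>) -> u32 {
        top_n_tokens.unwrap_or(0)
    }

    fn main() {
        assert_eq!(encode_top_n_tokens(Some(5)), 5);
        assert_eq!(encode_top_n_tokens(None), 0); // flag omitted on the CLI
    }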

--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs

@@ -22,6 +22,7 @@ pub async fn run(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -70,6 +71,7 @@ pub async fn run(
         batch_size.clone(),
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         parameters,
@@ -130,6 +132,7 @@ pub async fn run(
         tokenizer_name,
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         temperature,

--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs

@@ -93,6 +93,11 @@ struct Args {
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
     do_sample: bool,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_n_tokens: Option<u32>,
 }
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         watermark,
         do_sample,
         master_shard_uds_path,
+        top_n_tokens,
     } = args;
 
     let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         batch_size,
         sequence_length,
         decode_length,
+        top_n_tokens,
         runs,
         warmups,
         temperature,
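
Since the new `Args` field is an `Option<u32>`, clap derives a flag that simply defaults to absent: passing `--top-n-tokens 5` (or setting the `TOP_N_TOKENS` environment variable, thanks to `env`) yields `Some(5)`, while omitting it yields `None`. A minimal standalone sketch of the same derive (assumes clap with the `derive` and `env` features enabled; this is not the benchmark's full `Args`):

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Absent -> None; `--top-n-tokens 5` or TOP_N_TOKENS=5 -> Some(5).
        #[clap(long, env)]
        top_n_tokens: Option<u32>,
    }

    fn main() {
        let args = Args::parse();
        println!("top_n_tokens = {:?}", args.top_n_tokens);
    }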

--- a/benchmark/src/table.rs
+++ b/benchmark/src/table.rs

@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
     tokenizer_name: String,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Model", &tokenizer_name]);
     builder.push_record(["Sequence Length", &sequence_length.to_string()]);
     builder.push_record(["Decode Length", &decode_length.to_string()]);
+    builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
     builder.push_record(["N Runs", &n_runs.to_string()]);
     builder.push_record(["Warmups", &warmups.to_string()]);
     builder.push_record(["Temperature", &format!("{temperature:?}")]);