Add top-n-tokens support to benchmark

Author: Vincent Brouwers
Date: 2023-07-24 14:02:56 +00:00
Parent: 7c014c7dfe
Commit: a7be416c87
4 changed files with 17 additions and 6 deletions

@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
     // End task if a message is received on shutdown_receiver
     // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
-        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
             if let Err(err) = res {
                 run_sender.send(Err(err)).await.unwrap_or(());
             }
@@ -64,6 +65,7 @@ async fn generate_runs(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -73,9 +75,6 @@ async fn generate_runs(
     // Create a dummy sequence
     let sequence = create_sequence(sequence_length, tokenizer);

-    // TODO: Implement top_n_tokens
-    let top_n_tokens= 0;
-
     for b in batch_size {
         // Warmups on batch size
         for _ in 0..warmups {
@@ -135,7 +134,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
-    top_n_tokens: u32,
+    top_n_tokens: Option<u32>,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -151,7 +150,7 @@ async fn prefill(
                 stop_sequences: vec![],
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
-            top_n_tokens: top_n_tokens.unwrap_or(0),
+            top_n_tokens: top_n_tokens.unwrap_or(0),
         })
         .collect();
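The `Option<u32>` is only flattened at the request boundary: the `Request` built in `prefill` takes a plain `u32` (its parameter was previously typed `top_n_tokens: u32`), so `None` has to be encoded as a sentinel, and `unwrap_or(0)` relies on `0` meaning "disabled" — the same value the removed hard-coded `let top_n_tokens= 0;` used to send. A minimal standalone sketch of that mapping, not part of the commit:

/// Encode the benchmark-level `Option<u32>` into the request-level `u32`.
/// Assumes the server interprets `0` as "do not return top-n tokens".
fn encode_top_n_tokens(top_n_tokens: Option<u32>) -> u32 {
    top_n_tokens.unwrap_or(0)
}

fn main() {
    assert_eq!(encode_top_n_tokens(None), 0); // flag omitted -> disabled
    assert_eq!(encode_top_n_tokens(Some(5)), 5); // report the top 5 tokens per step
}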

@@ -22,6 +22,7 @@ pub async fn run(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -70,6 +71,7 @@ pub async fn run(
         batch_size.clone(),
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         parameters,
@@ -130,6 +132,7 @@ pub async fn run(
         tokenizer_name,
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         temperature,
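Note that `run` threads `top_n_tokens` through unchanged as an `Option<u32>`, both into `generation_task` and into the parameters table below; only `prefill` flattens it to a `u32`. This keeps an unset flag (`None`) distinguishable from an explicit `0` everywhere the value is displayed.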

@@ -93,6 +93,11 @@ struct Args {
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
     do_sample: bool,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_n_tokens: Option<u32>,
 }

 fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         watermark,
         do_sample,
         master_shard_uds_path,
+        top_n_tokens,
     } = args;

     let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         batch_size,
         sequence_length,
         decode_length,
+        top_n_tokens,
         runs,
         warmups,
         temperature,
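On the CLI side, `#[clap(long, env)]` on an `Option<u32>` field makes the new flag optional on the command line and also readable from an environment variable (clap derives `TOP_N_TOKENS` from the field name). A self-contained sketch of the parsing behaviour, using an illustrative one-field mirror of `Args` rather than the commit's full struct (assumes clap with the `derive` and `env` features enabled):

use clap::Parser;

/// Illustrative one-field stand-in for the benchmark's `Args`.
#[derive(Parser, Debug)]
struct Args {
    /// Number of top tokens to record per decode step (assumed semantics).
    #[clap(long, env)]
    top_n_tokens: Option<u32>,
}

fn main() {
    // `--top-n-tokens 5` parses into `Some(5)`...
    let args = Args::parse_from(["benchmark", "--top-n-tokens", "5"]);
    assert_eq!(args.top_n_tokens, Some(5));

    // ...and omitting the flag (with TOP_N_TOKENS unset) yields `None`.
    let args = Args::parse_from(["benchmark"]);
    assert_eq!(args.top_n_tokens, None);
}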

@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
     tokenizer_name: String,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Model", &tokenizer_name]);
     builder.push_record(["Sequence Length", &sequence_length.to_string()]);
     builder.push_record(["Decode Length", &decode_length.to_string()]);
+    builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
     builder.push_record(["N Runs", &n_runs.to_string()]);
     builder.push_record(["Warmups", &warmups.to_string()]);
     builder.push_record(["Temperature", &format!("{temperature:?}")]);