diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 1c30b1ed..930d0d72 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -37,6 +37,7 @@ pub(crate) async fn generation_task( batch_size: Vec, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, parameters: NextTokenChooserParameters, @@ -48,7 +49,7 @@ pub(crate) async fn generation_task( // End task if a message is received on shutdown_receiver // _shutdown_guard_sender will be dropped once the task is finished tokio::select! { - res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => { + res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => { if let Err(err) = res { run_sender.send(Err(err)).await.unwrap_or(()); } @@ -64,6 +65,7 @@ async fn generate_runs( batch_size: Vec, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, parameters: NextTokenChooserParameters, @@ -73,9 +75,6 @@ async fn generate_runs( // Create a dummy sequence let sequence = create_sequence(sequence_length, tokenizer); - // TODO: Implement top_n_tokens - let top_n_tokens= 0; - for b in batch_size { // Warmups on batch size for _ in 0..warmups { @@ -135,7 +134,7 @@ async fn prefill( batch_size: u32, decode_length: u32, parameters: NextTokenChooserParameters, - top_n_tokens: u32, + top_n_tokens: Option, client: &mut ShardedClient, ) -> Result<(Prefill, CachedBatch), ClientError> { // Create requests @@ -151,7 +150,7 @@ async fn prefill( stop_sequences: vec![], ignore_eos_token: true, // Will not stop even if a eos token is generated }), - top_n_tokens: top_n_tokens, + top_n_tokens: top_n_tokens.unwrap_or(0), }) .collect(); diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index fcad400c..433c6f67 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -22,6 +22,7 @@ pub async fn run( batch_size: Vec, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, temperature: Option, @@ -70,6 +71,7 @@ pub async fn run( batch_size.clone(), sequence_length, decode_length, + top_n_tokens, n_runs, warmups, parameters, @@ -130,6 +132,7 @@ pub async fn run( tokenizer_name, sequence_length, decode_length, + top_n_tokens, n_runs, warmups, temperature, diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index a7550060..97c8af1c 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -93,6 +93,11 @@ struct Args { /// decoding strategies, for full doc refer to the `text-generation-server` #[clap(long, env)] do_sample: bool, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_n_tokens: Option, } fn main() -> Result<(), Box> { @@ -117,6 +122,7 @@ fn main() -> Result<(), Box> { watermark, do_sample, master_shard_uds_path, + top_n_tokens, } = args; let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); @@ -173,6 +179,7 @@ fn main() -> Result<(), Box> { batch_size, sequence_length, decode_length, + top_n_tokens, runs, warmups, temperature, diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index 6b74bc36..9e36717b 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -7,6 +7,7 @@ pub(crate) fn parameters_table( tokenizer_name: String, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, temperature: Option, @@ -24,6 +25,7 @@ pub(crate) fn parameters_table( builder.push_record(["Model", &tokenizer_name]); builder.push_record(["Sequence Length", &sequence_length.to_string()]); builder.push_record(["Decode Length", &decode_length.to_string()]); + builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]); builder.push_record(["N Runs", &n_runs.to_string()]); builder.push_record(["Warmups", &warmups.to_string()]); builder.push_record(["Temperature", &format!("{temperature:?}")]);