mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Add top-n-tokens support to benchmark
This commit is contained in:
parent
7c014c7dfe
commit
a7be416c87
@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
parameters: NextTokenChooserParameters,
|
||||
@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
|
||||
// End task if a message is received on shutdown_receiver
|
||||
// _shutdown_guard_sender will be dropped once the task is finished
|
||||
tokio::select! {
|
||||
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
|
||||
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
|
||||
if let Err(err) = res {
|
||||
run_sender.send(Err(err)).await.unwrap_or(());
|
||||
}
|
||||
@ -64,6 +65,7 @@ async fn generate_runs(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
parameters: NextTokenChooserParameters,
|
||||
@ -73,9 +75,6 @@ async fn generate_runs(
|
||||
// Create a dummy sequence
|
||||
let sequence = create_sequence(sequence_length, tokenizer);
|
||||
|
||||
// TODO: Implement top_n_tokens
|
||||
let top_n_tokens= 0;
|
||||
|
||||
for b in batch_size {
|
||||
// Warmups on batch size
|
||||
for _ in 0..warmups {
|
||||
@ -135,7 +134,7 @@ async fn prefill(
|
||||
batch_size: u32,
|
||||
decode_length: u32,
|
||||
parameters: NextTokenChooserParameters,
|
||||
top_n_tokens: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
client: &mut ShardedClient,
|
||||
) -> Result<(Prefill, CachedBatch), ClientError> {
|
||||
// Create requests
|
||||
@ -151,7 +150,7 @@ async fn prefill(
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true, // Will not stop even if a eos token is generated
|
||||
}),
|
||||
top_n_tokens: top_n_tokens,
|
||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -22,6 +22,7 @@ pub async fn run(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
temperature: Option<f32>,
|
||||
@ -70,6 +71,7 @@ pub async fn run(
|
||||
batch_size.clone(),
|
||||
sequence_length,
|
||||
decode_length,
|
||||
top_n_tokens,
|
||||
n_runs,
|
||||
warmups,
|
||||
parameters,
|
||||
@ -130,6 +132,7 @@ pub async fn run(
|
||||
tokenizer_name,
|
||||
sequence_length,
|
||||
decode_length,
|
||||
top_n_tokens,
|
||||
n_runs,
|
||||
warmups,
|
||||
temperature,
|
||||
|
@ -93,6 +93,11 @@ struct Args {
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env)]
|
||||
do_sample: bool,
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env)]
|
||||
top_n_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
watermark,
|
||||
do_sample,
|
||||
master_shard_uds_path,
|
||||
top_n_tokens,
|
||||
} = args;
|
||||
|
||||
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
||||
@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
batch_size,
|
||||
sequence_length,
|
||||
decode_length,
|
||||
top_n_tokens,
|
||||
runs,
|
||||
warmups,
|
||||
temperature,
|
||||
|
@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
|
||||
tokenizer_name: String,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
temperature: Option<f32>,
|
||||
@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
|
||||
builder.push_record(["Model", &tokenizer_name]);
|
||||
builder.push_record(["Sequence Length", &sequence_length.to_string()]);
|
||||
builder.push_record(["Decode Length", &decode_length.to_string()]);
|
||||
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
|
||||
builder.push_record(["N Runs", &n_runs.to_string()]);
|
||||
builder.push_record(["Warmups", &warmups.to_string()]);
|
||||
builder.push_record(["Temperature", &format!("{temperature:?}")]);
|
||||
|
Loading…
Reference in New Issue
Block a user