Mirror of https://github.com/huggingface/text-generation-inference.git

Add top-n-tokens support to benchmark

parent 8a4d2076a6 · commit 0facd94738

This commit threads a new optional `top_n_tokens` parameter through the benchmark crate: an `Option<u32>` CLI argument is added to `Args`, passed through `run` and `generation_task` into `generate_runs` and `prefill` (replacing a hardcoded placeholder of `0`), and reported in the parameters table.
@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
     // End task if a message is received on shutdown_receiver
     // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
-        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
            if let Err(err) = res {
                run_sender.send(Err(err)).await.unwrap_or(());
            }
@@ -64,6 +65,7 @@ async fn generate_runs(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -73,9 +75,6 @@ async fn generate_runs(
     // Create a dummy sequence
     let sequence = create_sequence(sequence_length, tokenizer);

-    // TODO: Implement top_n_tokens
-    let top_n_tokens= 0;
-
     for b in batch_size {
         // Warmups on batch size
         for _ in 0..warmups {
@@ -135,7 +134,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
-    top_n_tokens: u32,
+    top_n_tokens: Option<u32>,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -151,7 +150,7 @@ async fn prefill(
             stop_sequences: vec![],
             ignore_eos_token: true, // Will not stop even if a eos token is generated
         }),
-        top_n_tokens: top_n_tokens,
+        top_n_tokens: top_n_tokens.unwrap_or(0),
     })
     .collect();
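The last two hunks establish the convention the rest of the diff relies on: the request field itself stays a plain `u32`, so the `Option` is collapsed with `unwrap_or(0)`, and `0` keeps the previous "disabled" behavior. A minimal runnable analog (the function names mirror the diff, but the bodies are simplified stand-ins, not the real signatures):

// Minimal analog of the change: the caller now owns the value and passes it
// down, instead of `generate_runs` hardcoding `0` internally.
fn prefill(top_n_tokens: Option<u32>) -> u32 {
    // The request field stays a plain u32; 0 preserves the old "disabled" default.
    top_n_tokens.unwrap_or(0)
}

fn generate_runs(top_n_tokens: Option<u32>) -> u32 {
    // Before this commit, `let top_n_tokens = 0;` lived here as a placeholder.
    prefill(top_n_tokens)
}

fn main() {
    assert_eq!(generate_runs(None), 0); // default behavior unchanged
    assert_eq!(generate_runs(Some(5)), 5); // new: request the top 5 tokens
}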
@@ -22,6 +22,7 @@ pub async fn run(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -70,6 +71,7 @@ pub async fn run(
         batch_size.clone(),
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         parameters,
@@ -130,6 +132,7 @@ pub async fn run(
         tokenizer_name,
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         temperature,
@@ -93,6 +93,11 @@ struct Args {
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
     do_sample: bool,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_n_tokens: Option<u32>,
 }

 fn main() -> Result<(), Box<dyn std::error::Error>> {
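For context, `#[clap(long, env)]` makes clap derive a `--top-n-tokens <N>` flag with a `TOP_N_TOKENS` environment-variable fallback (assuming the crate's `env` feature is enabled), leaving the field `None` when neither is set. A stripped-down sketch keeping only the new field; everything outside the attribute is illustrative:

use clap::Parser;

/// Stripped-down stand-in for the benchmark's `Args`, reduced to the new field.
#[derive(Parser, Debug)]
struct Args {
    /// Number of top tokens to report for each generated token.
    #[clap(long, env)]
    top_n_tokens: Option<u32>,
}

fn main() {
    let args = Args::parse();
    // `--top-n-tokens 5` or `TOP_N_TOKENS=5` yields Some(5); otherwise None.
    println!("top_n_tokens = {:?}", args.top_n_tokens);
}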
@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         watermark,
         do_sample,
         master_shard_uds_path,
+        top_n_tokens,
     } = args;

     let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         batch_size,
         sequence_length,
         decode_length,
+        top_n_tokens,
         runs,
         warmups,
         temperature,
@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
     tokenizer_name: String,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Model", &tokenizer_name]);
     builder.push_record(["Sequence Length", &sequence_length.to_string()]);
     builder.push_record(["Decode Length", &decode_length.to_string()]);
+    builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
     builder.push_record(["N Runs", &n_runs.to_string()]);
     builder.push_record(["Warmups", &warmups.to_string()]);
     builder.push_record(["Temperature", &format!("{temperature:?}")]);
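As with the existing `Temperature` row, the new row renders the value with the `Debug` formatter, so the table shows `Some(5)` or `None` rather than a bare number. A quick check of that formatting:

fn main() {
    let top_n_tokens: Option<u32> = Some(5);
    // Debug formatting keeps the Option variant visible in the table.
    assert_eq!(format!("{top_n_tokens:?}"), "Some(5)");
    assert_eq!(format!("{:?}", None::<u32>), "None");
}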