diff --git a/assets/benchmark.png b/assets/benchmark.png new file mode 100644 index 00000000..6699f0e2 Binary files /dev/null and b/assets/benchmark.png differ diff --git a/benchmark/README.md b/benchmark/README.md index 2bc0d4d9..7f51a731 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,5 +1,11 @@ +
+ # Text Generation Inference benchmarking tool +![benchmark](../assets/benchmark.png) + +
+ A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha) and powered by [tui](https://github.com/tui-rs-revival/ratatui). diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs index 35c7e703..726ea3e4 100644 --- a/benchmark/src/app.rs +++ b/benchmark/src/app.rs @@ -19,6 +19,7 @@ pub(crate) struct App { completed_batch: usize, current_batch: usize, current_tab: usize, + touched_tab: bool, zoom: bool, is_error: bool, data: Data, @@ -53,6 +54,7 @@ impl App { completed_batch, current_batch, current_tab, + touched_tab: false, zoom: false, is_error, data, @@ -76,6 +78,7 @@ impl App { | KeyEvent { code: KeyCode::Tab, .. } => { + self.touched_tab=true; self.current_tab = (self.current_tab + 1) % self.batch_size.len(); } // Decrease and wrap tab @@ -83,6 +86,7 @@ impl App { code: KeyCode::Left, .. } => { + self.touched_tab=true; if self.current_tab > 0 { self.current_tab -= 1; } else { @@ -131,9 +135,14 @@ impl App { } Message::EndBatch => { self.data.end_batch(self.current_batch); - self.completed_batch += 1; + if self.current_batch < self.batch_size.len() - 1 { + // Only go to next tab if the user never touched the tab keys + if !self.touched_tab { + self.current_tab += 1; + } + self.current_batch += 1; } } diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 481be2e0..45de021a 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -14,19 +14,19 @@ use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { - #[clap(long, env)] + #[clap(short, long, env)] tokenizer_name: String, - #[clap(long)] + #[clap(short, long)] batch_size: Option>, - #[clap(default_value = "10", long, env)] + #[clap(default_value = "10", short, long, env)] sequence_length: u32, - #[clap(default_value = "64", long, env)] + #[clap(default_value = "8", short,long, env)] decode_length: u32, - #[clap(default_value = "10", long, env)] + #[clap(default_value = "10", short,long, env)] 
runs: usize, - #[clap(default_value = "1", long, env)] + #[clap(default_value = "1", short,long, env)] warmups: usize, - #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)] master_shard_uds_path: String, } diff --git a/router/src/queue.rs b/router/src/queue.rs index df2087e1..2899ccd4 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -237,6 +237,7 @@ mod tests { watermark: false, }, stopping_parameters: StoppingCriteriaParameters { + ignore_eos_token: false, max_new_tokens: 0, stop_sequences: vec![], },