From b6df2036ed7743f46ca360bc776d928972d1bf4c Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 30 Mar 2023 12:36:17 +0200 Subject: [PATCH] v1 --- Makefile | 3 + benchmark/Cargo.toml | 2 +- benchmark/README.md | 24 +++ benchmark/src/{ui.rs => app.rs} | 265 ++++++++++++++++++-------------- benchmark/src/event.rs | 16 +- benchmark/src/generation.rs | 59 ++++--- benchmark/src/lib.rs | 49 +++--- benchmark/src/main.rs | 16 +- 8 files changed, 265 insertions(+), 169 deletions(-) create mode 100644 benchmark/README.md rename benchmark/src/{ui.rs => app.rs} (91%) diff --git a/Makefile b/Makefile index 3defd886..21fd11b5 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,9 @@ install-router: install-launcher: cd launcher && cargo install --path . +install-benchmark: + cd benchmark && cargo install --path . + install: install-server install-router install-launcher server-dev: diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index ad67896c..4c1defda 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -9,7 +9,7 @@ description = "Text Generation Benchmarking tool" path = "src/lib.rs" [[bin]] -name = "text-generation-bench" +name = "text-generation-benchmark" path = "src/main.rs" [dependencies] diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..2bc0d4d9 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,24 @@ +# Text Generation Inference benchmarking tool + +A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha) +and powered by [tui](https://github.com/tui-rs-revival/ratatui). + +## Install + +```shell +make install-benchmark +``` + +## Run + +First, start `text-generation-inference`: + +```shell +text-generation-launcher --model-id bigscience/bloom-560m +``` + +Then run the benchmarking tool: + +```shell +text-generation-benchmark --tokenizer-name bigscience/bloom-560m +``` \ No newline at end of file diff --git a/benchmark/src/ui.rs b/benchmark/src/app.rs similarity index 91% rename from benchmark/src/ui.rs rename to benchmark/src/app.rs index e824ace7..35c7e703 100644 --- a/benchmark/src/ui.rs +++ b/benchmark/src/app.rs @@ -1,5 +1,5 @@ +/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs use crate::generation::{Decode, Message, Prefill}; -/// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; use text_generation_client::ClientError; use tokio::sync::mpsc; @@ -12,70 +12,8 @@ use tui::widgets::{ }; use tui::{symbols, Frame}; -struct Data { - prefill_latencies: Vec>, - prefill_throughputs: Vec>, - decode_latencies: Vec>, - decode_throughputs: Vec>, - prefill_batch_latency_throughput: Vec<(f64, f64)>, - decode_batch_latency_throughput: Vec<(f64, f64)>, -} - -impl Data { - fn new(n_run: usize, n_batch: usize) -> Self { - let prefill_latencies: Vec> = - (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); - let prefill_throughputs: Vec> = - (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); - - let decode_latencies: Vec> = - (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); - let decode_throughputs: Vec> = - (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); - - let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch); - - let decode_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch); - - Self { - prefill_latencies, - prefill_throughputs, - decode_latencies, 
- decode_throughputs, - prefill_batch_latency_throughput, - decode_batch_latency_throughput, - } - } - - fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) { - let latency = prefill.latency.as_millis() as f64; - self.prefill_latencies[batch_idx].push(latency); - self.prefill_throughputs[batch_idx].push(prefill.throughput); - } - - fn push_decode(&mut self, prefill: Decode, batch_idx: usize) { - let latency = prefill.latency.as_millis() as f64; - self.decode_latencies[batch_idx].push(latency); - self.decode_throughputs[batch_idx].push(prefill.throughput); - } - - fn end_batch(&mut self, batch_idx: usize) { - self.prefill_batch_latency_throughput.push(( - self.prefill_latencies[batch_idx].iter().sum::() - / self.prefill_latencies[batch_idx].len() as f64, - self.prefill_throughputs[batch_idx].iter().sum::() - / self.prefill_throughputs[batch_idx].len() as f64, - )); - self.decode_batch_latency_throughput.push(( - self.decode_latencies[batch_idx].iter().sum::() - / self.decode_latencies[batch_idx].len() as f64, - self.decode_throughputs[batch_idx].iter().sum::() - / self.decode_throughputs[batch_idx].len() as f64, - )); - } -} - -pub(crate) struct UI { +/// TUI powered App +pub(crate) struct App { pub(crate) running: bool, completed_runs: Vec, completed_batch: usize, @@ -92,7 +30,7 @@ pub(crate) struct UI { receiver: mpsc::Receiver>, } -impl UI { +impl App { pub(crate) fn new( receiver: mpsc::Receiver>, tokenizer_name: String, @@ -127,18 +65,20 @@ impl UI { } } + /// Handle crossterm key events pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) { match key_event { + // Increase and wrap tab KeyEvent { code: KeyCode::Right, .. - } | - KeyEvent { - code: KeyCode::Tab, - .. + } + | KeyEvent { + code: KeyCode::Tab, .. } => { self.current_tab = (self.current_tab + 1) % self.batch_size.len(); } + // Decrease and wrap tab KeyEvent { code: KeyCode::Left, .. @@ -149,19 +89,21 @@ impl UI { self.current_tab = self.batch_size.len() - 1; } } + // Zoom on throughput/latency fig KeyEvent { code: KeyCode::Char('+'), .. } => { self.zoom = true; } + // Unzoom on throughput/latency fig KeyEvent { code: KeyCode::Char('-'), .. } => { self.zoom = false; } - + // Quit KeyEvent { code: KeyCode::Char('q'), .. 
@@ -177,13 +119,14 @@ impl UI { } } + /// Get all pending messages from generation task pub(crate) fn tick(&mut self) { while let Ok(message) = self.receiver.try_recv() { match message { Ok(message) => match message { Message::Prefill(step) => self.data.push_prefill(step, self.current_batch), Message::Decode(step) => self.data.push_decode(step, self.current_batch), - Message::Run(_) => { + Message::EndRun => { self.completed_runs[self.current_batch] += 1; } Message::EndBatch => { @@ -201,6 +144,7 @@ impl UI { } } + /// Render frame pub fn render(&mut self, f: &mut Frame<'_, B>) { let batch_progress = (self.completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0); @@ -218,7 +162,7 @@ impl UI { Constraint::Length(13), Constraint::Min(10), ] - .as_ref(), + .as_ref(), ) .split(f.size()); @@ -238,7 +182,7 @@ impl UI { Constraint::Percentage(20), Constraint::Percentage(30), ] - .as_ref(), + .as_ref(), ) .split(row5[3]); @@ -277,14 +221,9 @@ impl UI { // Helper let helper = Block::default() .borders(Borders::NONE) - .title(format!( - "<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom" - )) + .title("<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom") .title_alignment(Alignment::Right) - .style( - Style::default() - .fg(Color::White), - ); + .style(Style::default().fg(Color::White)); f.render_widget(helper, row5[0]); // Batch tabs @@ -356,7 +295,7 @@ impl UI { } else { (mid[1].width as usize - 2) / (histo_width + 1) } - .max(2); + .max(2); let histo_data = latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins); @@ -404,6 +343,71 @@ impl UI { } } +/// App internal data struct +struct Data { + prefill_latencies: Vec>, + prefill_throughputs: Vec>, + decode_latencies: Vec>, + decode_throughputs: Vec>, + prefill_batch_latency_throughput: Vec<(f64, f64)>, + decode_batch_latency_throughput: Vec<(f64, f64)>, +} + +impl Data { + fn new(n_run: usize, n_batch: usize) -> Self { + let prefill_latencies: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + let prefill_throughputs: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + + let decode_latencies: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + let decode_throughputs: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + + let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch); + + let decode_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch); + + Self { + prefill_latencies, + prefill_throughputs, + decode_latencies, + decode_throughputs, + prefill_batch_latency_throughput, + decode_batch_latency_throughput, + } + } + + fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) { + let latency = prefill.latency.as_millis() as f64; + self.prefill_latencies[batch_idx].push(latency); + self.prefill_throughputs[batch_idx].push(prefill.throughput); + } + + fn push_decode(&mut self, prefill: Decode, batch_idx: usize) { + let latency = prefill.latency.as_millis() as f64; + self.decode_latencies[batch_idx].push(latency); + self.decode_throughputs[batch_idx].push(prefill.throughput); + } + + fn end_batch(&mut self, batch_idx: usize) { + self.prefill_batch_latency_throughput.push(( + self.prefill_latencies[batch_idx].iter().sum::() + / self.prefill_latencies[batch_idx].len() as f64, + self.prefill_throughputs[batch_idx].iter().sum::() + / self.prefill_throughputs[batch_idx].len() as f64, + )); + self.decode_batch_latency_throughput.push(( + 
self.decode_latencies[batch_idx].iter().sum::() + / self.decode_latencies[batch_idx].len() as f64, + self.decode_throughputs[batch_idx].iter().sum::() + / self.decode_throughputs[batch_idx].len() as f64, + )); + } +} + +/// Progress bar fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge { Gauge::default() .block(Block::default().title(title).borders(Borders::ALL)) @@ -412,31 +416,40 @@ fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Ga .ratio(progress) } +/// Prefill or Decode text infos fn text_info<'a>( latency: &mut Vec, throughput: &Vec, name: &'static str, ) -> (Paragraph<'a>, Paragraph<'a>) { - let mut latency_texts = statis_spans(&latency, "ms"); + // Latency average/high/low texts + let mut latency_texts = statis_spans(latency, "ms"); + + // Sort latency for percentiles float_ord::sort(latency); let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]); + + // Latency p50/p90/p99 texts let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed]; for (i, (name, value)) in latency_percentiles.iter().enumerate() { let span = Spans::from(vec![Span::styled( - format!("{name}: {:.4} ms", value), + format!("{name}: {value:.4} ms"), Style::default().fg(colors[i]), )]); latency_texts.push(span); } - let throughput_texts = statis_spans(&throughput, "tokens/secs"); + // Throughput average/high/low texts + let throughput_texts = statis_spans(throughput, "tokens/secs"); + // Latency Block let latency_statics = Paragraph::new(latency_texts).block( Block::default() .title(Span::raw(format!("{name} Latency"))) .borders(Borders::ALL), ); + // Throughput block let throughput_statics = Paragraph::new(throughput_texts).block( Block::default() .title(Span::raw(format!("{name} Throughput"))) @@ -446,32 +459,7 @@ fn text_info<'a>( (latency_statics, throughput_statics) } -fn latency_histogram_data(latency: &Vec, bins: usize) -> Vec<(String, u64)> { - let histo_data: Vec<(String, u64)> = { - let histo = crate::utils::histogram(latency, bins); - histo - .into_iter() - .map(|(label, v)| (format!("{label:.2}"), v as u64)) - .collect() - }; - - histo_data -} - -fn latency_histogram<'a>( - histo_data_str: &'a Vec<(&'a str, u64)>, - name: &'static str, -) -> BarChart<'a> { - BarChart::default() - .block( - Block::default() - .title(format!("{name} latency histogram")) - .style(Style::default().fg(Color::LightYellow).bg(Color::Reset)) - .borders(Borders::ALL), - ) - .data(histo_data_str.as_slice()) -} - +/// Average/High/Low spans fn statis_spans<'a>(data: &Vec, unit: &'static str) -> Vec> { vec![ Spans::from(vec![Span::styled( @@ -502,15 +490,45 @@ fn statis_spans<'a>(data: &Vec, unit: &'static str) -> Vec> { ] } +/// Latency histogram data +fn latency_histogram_data(latency: &[f64], bins: usize) -> Vec<(String, u64)> { + let histo_data: Vec<(String, u64)> = { + let histo = crate::utils::histogram(latency, bins); + histo + .into_iter() + .map(|(label, v)| (format!("{label:.2}"), v as u64)) + .collect() + }; + + histo_data +} + +/// Latency Histogram +fn latency_histogram<'a>( + histo_data_str: &'a Vec<(&'a str, u64)>, + name: &'static str, +) -> BarChart<'a> { + BarChart::default() + .block( + Block::default() + .title(format!("{name} latency histogram")) + .style(Style::default().fg(Color::LightYellow).bg(Color::Reset)) + .borders(Borders::ALL), + ) + .data(histo_data_str.as_slice()) +} + +/// Latency/Throughput chart fn latency_throughput_chart<'a>( latency_throughput: &'a Vec<(f64, f64)>, - batch_sizes: &'a Vec, + 
batch_sizes: &'a [u32], zoom: bool, name: &'static str, ) -> Chart<'a> { let latency_iter = latency_throughput.iter().map(|(l, _)| l); let throughput_iter = latency_throughput.iter().map(|(_, t)| t); + // Get extreme values let min_latency: f64 = *latency_iter .clone() .min_by(|a, b| a.total_cmp(b)) @@ -526,6 +544,7 @@ fn latency_throughput_chart<'a>( .max_by(|a, b| a.total_cmp(b)) .unwrap_or(&std::f64::NAN); + // Char min max values let min_x = if zoom { ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0 } else { @@ -534,6 +553,7 @@ fn latency_throughput_chart<'a>( let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0; let step_x = (max_x - min_x) / 4.0; + // Chart min max values let min_y = if zoom { ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0 } else { @@ -542,8 +562,9 @@ fn latency_throughput_chart<'a>( let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0; let step_y = (max_y - min_y) / 4.0; + // Labels let mut x_labels = vec![Span::styled( - format!("{:.2}", min_x), + format!("{min_x:.2}"), Style::default() .add_modifier(Modifier::BOLD) .fg(Color::Gray) @@ -556,15 +577,16 @@ fn latency_throughput_chart<'a>( )); } x_labels.push(Span::styled( - format!("{:.2}", max_x), + format!("{max_x:.2}"), Style::default() .add_modifier(Modifier::BOLD) .fg(Color::Gray) .bg(Color::Reset), )); + // Labels let mut y_labels = vec![Span::styled( - format!("{:.2}", min_y), + format!("{min_y:.2}"), Style::default() .add_modifier(Modifier::BOLD) .fg(Color::Gray) @@ -577,25 +599,29 @@ fn latency_throughput_chart<'a>( )); } y_labels.push(Span::styled( - format!("{:.2}", max_y), + format!("{max_y:.2}"), Style::default() .add_modifier(Modifier::BOLD) .fg(Color::Gray) .bg(Color::Reset), )); + // Chart dataset let colors = color_vec(); let datasets: Vec = (0..latency_throughput.len()) .map(|i| { + let color_idx = i % colors.len(); + Dataset::default() .name(batch_sizes[i].to_string()) .marker(symbols::Marker::Block) - .style(Style::default().fg(colors[i])) + .style(Style::default().fg(colors[color_idx])) .graph_type(GraphType::Scatter) .data(&latency_throughput[i..(i + 1)]) }) .collect(); + // Chart Chart::new(datasets) .style(Style::default().fg(Color::Cyan).bg(Color::Reset)) .block( @@ -608,20 +634,21 @@ fn latency_throughput_chart<'a>( ) .x_axis( Axis::default() - .title(format!("ms")) + .title("ms") .style(Style::default().fg(Color::Gray).bg(Color::Reset)) .labels(x_labels) .bounds([min_x, max_x]), ) .y_axis( Axis::default() - .title(format!("tokens/secs")) + .title("tokens/secs") .style(Style::default().fg(Color::Gray).bg(Color::Reset)) .labels(y_labels) .bounds([min_y, max_y]), ) } +// Colors for latency/throughput chart fn color_vec() -> Vec { vec![ Color::Red, diff --git a/benchmark/src/event.rs b/benchmark/src/event.rs index 32d63fc8..91ce8400 100644 --- a/benchmark/src/event.rs +++ b/benchmark/src/event.rs @@ -1,4 +1,4 @@ -/// Inspired by https://github.com/orhun/rust-tui-template +/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs use crossterm::event; use std::time::{Duration, Instant}; use tokio::sync::{broadcast, mpsc}; @@ -20,6 +20,8 @@ pub(crate) async fn terminal_event_task( mut shutdown_receiver: broadcast::Receiver<()>, _shutdown_guard_sender: mpsc::Sender<()>, ) { + // End task if a message is received on shutdown_receiver + // _shutdown_guard_sender will be dropped once the task is finished tokio::select! 
{ _ = event_loop(fps, event_sender) => { }, @@ -27,14 +29,21 @@ pub(crate) async fn terminal_event_task( } } +/// Main event loop async fn event_loop(fps: u32, event_sender: mpsc::Sender) { - let per_frame = Duration::from_secs(1) / fps as u32; + // Frame budget + let per_frame = Duration::from_secs(1) / fps; + + // When was last frame executed let mut last_frame = Instant::now(); + loop { + // Sleep to avoid blocking the thread for too long if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) { tokio::time::sleep(sleep).await; } + // Get crossterm event and send a new one over the channel if event::poll(Duration::from_secs(0)).expect("no events available") { match event::read().expect("unable to read event") { event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()), @@ -45,8 +54,11 @@ async fn event_loop(fps: u32, event_sender: mpsc::Sender) { } } + // Frame budget exceeded if last_frame.elapsed() >= per_frame { + // Send tick event_sender.send(Event::Tick).await.unwrap_or(()); + // Rest last_frame time last_frame = Instant::now(); } } diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index eb3b9201..3bdff1d1 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -16,28 +16,21 @@ pub(crate) struct Prefill { #[derive(Debug, Clone)] pub(crate) struct Decode { - pub(crate) decode_length: u32, pub(crate) latency: Duration, pub(crate) throughput: f64, } -#[derive(Debug)] -pub(crate) struct Run { - pub(crate) batch_size: u32, - pub(crate) sequence_length: u32, - pub(crate) prefill: Prefill, - pub(crate) decode: Decode, -} - #[derive(Debug)] pub(crate) enum Message { Warmup, Prefill(Prefill), Decode(Decode), - Run(Run), + EndRun, EndBatch, } +/// Benchmarking task +#[allow(clippy::too_many_arguments)] pub(crate) async fn generation_task( tokenizer: Tokenizer, batch_size: Vec, @@ -50,6 +43,8 @@ pub(crate) async fn generation_task( mut shutdown_receiver: broadcast::Receiver<()>, _shutdown_guard_sender: mpsc::Sender<()>, ) { + // End task if a message is received on shutdown_receiver + // _shutdown_guard_sender will be dropped once the task is finished tokio::select! 
{ res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => { if let Err(err) = res { @@ -60,6 +55,8 @@ pub(crate) async fn generation_task( } } +/// Benchmark prefill/decode +#[allow(clippy::too_many_arguments)] async fn generate_runs( tokenizer: Tokenizer, batch_size: Vec, @@ -70,52 +67,53 @@ async fn generate_runs( mut client: ShardedClient, run_sender: mpsc::Sender>, ) -> Result<(), ClientError> { + // Create a dummy sequence let sequence = create_sequence(sequence_length, tokenizer); for b in batch_size { + // Warmups on batch size for _ in 0..warmups { let (_, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?; let _ = decode(decode_batch, &mut client).await?; + // Send warmup message run_sender.send(Ok(Message::Warmup)).await.unwrap_or(()); } for _ in 0..n_runs { let (prefill, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?; + // Send prefill message run_sender - .send(Ok(Message::Prefill(prefill.clone()))) + .send(Ok(Message::Prefill(prefill))) .await .unwrap_or(()); let decode = decode(decode_batch, &mut client).await?; + // Send decode message run_sender - .send(Ok(Message::Decode(decode.clone()))) + .send(Ok(Message::Decode(decode))) .await .unwrap_or(()); - run_sender - .send(Ok(Message::Run(Run { - batch_size: b, - sequence_length, - prefill, - decode, - }))) - .await - .unwrap_or(()); + // Send run ended message + run_sender.send(Ok(Message::EndRun)).await.unwrap_or(()); } + // Batch ended run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(()); } Ok(()) } +// Run a prefill step async fn prefill( sequence: String, batch_size: u32, decode_length: u32, client: &mut ShardedClient, ) -> Result<(Prefill, Batch), ClientError> { + // Create requests let requests = (0..batch_size) .map(|id| Request { id: id.into(), @@ -133,7 +131,7 @@ async fn prefill( stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: decode_length, stop_sequences: vec![], - ignore_eos_token: true, + ignore_eos_token: true, // Will not stop even if a eos token is generated }), }) .collect(); @@ -144,11 +142,17 @@ async fn prefill( size: batch_size, }; + // Run prefill let start_time = Instant::now(); let (_, decode_batch) = client.prefill(batch.clone()).await?; + + // Get latency let latency = start_time.elapsed(); + + // Compute throughput from latency and batch size let throughput = batch_size as f64 / latency.as_secs_f64(); + // Decode batch cannot be empty let decode_batch = decode_batch.expect("decode_batch is None. 
This is a bug."); let step = Prefill { @@ -159,28 +163,35 @@ async fn prefill( Ok((step, decode_batch)) } +/// Run a full decode async fn decode(batch: Batch, client: &mut ShardedClient) -> Result { let mut decode_length = 0; - let start_time = Instant::now(); let batch_size = batch.size; + let start_time = Instant::now(); + + // Full decode over decode length let mut next_batch = Some(batch); while let Some(batch) = next_batch { let result = client.decode(vec![batch]).await?; next_batch = result.1; decode_length += 1; } + + // Get latency let latency = start_time.elapsed(); + + // Compute throughput from latency, batch size and decode length let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64(); let step = Decode { - decode_length, latency, throughput, }; Ok(step) } +/// Create a dummy sequence of the correct length fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String { let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len(); // Repeat lorem ipsum to cover sequence length diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 60de542a..4da0b573 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -1,12 +1,10 @@ -extern crate core; - +mod app; mod event; mod generation; -mod ui; mod utils; +use crate::app::App; use crate::event::Event; -use crate::ui::UI; use crossterm::ExecutableCommand; use std::io; use text_generation_client::ShardedClient; @@ -15,6 +13,8 @@ use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; use tui::Terminal; +/// Run benchmarking app +#[allow(clippy::too_many_arguments)] pub async fn run( tokenizer_name: String, tokenizer: Tokenizer, @@ -25,11 +25,27 @@ pub async fn run( warmups: usize, client: ShardedClient, ) -> Result<(), crossterm::ErrorKind> { + // Initialize terminal properties + crossterm::terminal::enable_raw_mode()?; + io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; + io::stdout().execute(crossterm::cursor::Hide)?; + + // Initialize terminal + let mut terminal = { + let backend = CrosstermBackend::new(io::stdout()); + Terminal::new(backend)? + }; + + // Create message channel between generation_task and app let (run_sender, run_receiver) = mpsc::channel(8); + // Crossterm event channel let (event_sender, mut event_receiver) = mpsc::channel(8); + // Shutdown channel to terminate tasks let (shutdown_sender, _) = broadcast::channel(1); + // Channel to check if tasks terminated let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1); + // Create generation task tokio::spawn(generation::generation_task( tokenizer, batch_size.clone(), @@ -43,6 +59,7 @@ pub async fn run( shutdown_guard_sender.clone(), )); + // Create event task tokio::spawn(event::terminal_event_task( 250, event_sender, @@ -50,9 +67,11 @@ pub async fn run( shutdown_guard_sender.clone(), )); + // Drop our end of shutdown sender drop(shutdown_guard_sender); - let mut ui = UI::new( + // Create App + let mut app = App::new( run_receiver, tokenizer_name, sequence_length, @@ -61,23 +80,17 @@ pub async fn run( batch_size, ); - crossterm::terminal::enable_raw_mode()?; - io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; - io::stdout().execute(crossterm::cursor::Hide)?; - - let mut terminal = { - let backend = CrosstermBackend::new(io::stdout()); - Terminal::new(backend)? 
- }; - - while ui.running { - terminal.draw(|frame| ui.render(frame))?; + while app.running { + // Draw frame + terminal.draw(|frame| app.render(frame))?; + // Await a new event from event handling task match event_receiver.recv().await { None => break, + // Update app state Some(event) => match event { - Event::Tick => ui.tick(), - Event::Key(key_event) => ui.handle_key_event(key_event), + Event::Tick => app.tick(), + Event::Key(key_event) => app.handle_key_event(key_event), _ => {} }, } diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 5d33b668..481be2e0 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -1,5 +1,8 @@ -use clap::Parser; /// Text Generation Inference benchmarking tool +/// +/// Inspired by the great Oha app: https://github.com/hatoo/oha +/// and: https://github.com/orhun/rust-tui-template +use clap::Parser; use std::path::Path; use text_generation_client::ShardedClient; use tokenizers::Tokenizer; @@ -11,17 +14,17 @@ use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { - #[clap(default_value = "bigscience/bloom", long, env)] + #[clap(long, env)] tokenizer_name: String, - #[clap(default_value = "1", long, env)] - batch_size: Vec, + #[clap(long)] + batch_size: Option>, #[clap(default_value = "10", long, env)] sequence_length: u32, #[clap(default_value = "64", long, env)] decode_length: u32, #[clap(default_value = "10", long, env)] runs: usize, - #[clap(default_value = "2", long, env)] + #[clap(default_value = "1", long, env)] warmups: usize, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, @@ -41,6 +44,8 @@ fn main() -> Result<(), Box> { master_shard_uds_path, } = args; + let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); + init_logging(); // Tokenizer instance @@ -79,6 +84,7 @@ fn main() -> Result<(), Box> { .expect("Unable to clear cache"); tracing::info!("Connected"); + // Run app text_generation_benchmark::run( tokenizer_name, tokenizer,
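For context, here is a hedged sketch of how the reworked CLI from `benchmark/src/main.rs` might be invoked once this patch is applied. The flag names assume clap's default kebab-case mapping of the `Args` fields shown above, and the repeated `--batch-size` form assumes clap's default append behavior for `Vec` arguments; when the flag is omitted, the patch falls back to batch sizes `1, 2, 4, 8, 16, 32`. Verify the exact flags against `text-generation-benchmark --help`.

```shell
# Assumes a text-generation-inference server is already running, e.g.:
#   text-generation-launcher --model-id bigscience/bloom-560m

# Run with the patch defaults: batch sizes 1 2 4 8 16 32,
# sequence length 10, decode length 64, 10 runs, 1 warmup
text-generation-benchmark --tokenizer-name bigscience/bloom-560m

# Restrict to a few batch sizes by repeating the flag (clap appends repeated values)
text-generation-benchmark \
    --tokenizer-name bigscience/bloom-560m \
    --batch-size 1 --batch-size 8 --batch-size 32 \
    --sequence-length 10 \
    --decode-length 64 \
    --runs 10 \
    --warmups 1
```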