From 1c5d52694341e80841f7ea269526362a0c3ca1d3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 29 Mar 2023 14:01:23 +0200 Subject: [PATCH] improvements --- benchmark/src/lib.rs | 34 ++++++++++++---------- benchmark/src/main.rs | 9 +++--- benchmark/src/ui.rs | 65 +++++++++++++++++++++++++------------------ 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index d3a167ac..d30745c1 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -44,6 +44,7 @@ pub(crate) enum Message { } pub async fn run( + tokenizer_name: String, tokenizer: Tokenizer, batch_size: Vec, sequence_length: u32, @@ -57,6 +58,9 @@ pub async fn run( tokio::spawn( UI { + tokenizer_name, + decode_length, + sequence_length, n_run: n_runs, batch_size: batch_size.clone(), receiver: ui_receiver, @@ -68,22 +72,22 @@ pub async fn run( let mut runs = Vec::with_capacity(batch_size.len() * n_runs); let sequence = create_sequence(sequence_length, tokenizer); - for _ in 0..warmups { - let (_, decode_batch) = tokio::select! { - res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?, - _ = shutdown_receiver.recv() => { - return Ok(()); - } - }; - let _ = tokio::select! { - res = run_decode(decode_batch, sequence_length, &mut client) => res?, - _ = shutdown_receiver.recv() => { - return Ok(()); - } - }; - } - for b in batch_size { + for _ in 0..warmups { + let (_, decode_batch) = tokio::select! { + res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?, + _ = shutdown_receiver.recv() => { + return Ok(()); + } + }; + let _ = tokio::select! { + res = run_decode(decode_batch, sequence_length, &mut client) => res?, + _ = shutdown_receiver.recv() => { + return Ok(()); + } + }; + } + for _ in 0..n_runs { let (prefill, decode_batch) = tokio::select! { res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?, diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 06941868..5d33b668 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -15,13 +15,13 @@ struct Args { tokenizer_name: String, #[clap(default_value = "1", long, env)] batch_size: Vec, - #[clap(default_value = "12", long, env)] - sequence_length: u32, #[clap(default_value = "10", long, env)] + sequence_length: u32, + #[clap(default_value = "64", long, env)] decode_length: u32, #[clap(default_value = "10", long, env)] runs: usize, - #[clap(default_value = "0", long, env)] + #[clap(default_value = "2", long, env)] warmups: usize, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, @@ -74,12 +74,13 @@ fn main() -> Result<(), Box> { .expect("Could not connect to server"); // Clear the cache; useful if the webserver rebooted sharded_client - .clear_cache() + .clear_cache(None) .await .expect("Unable to clear cache"); tracing::info!("Connected"); text_generation_benchmark::run( + tokenizer_name, tokenizer, batch_size, sequence_length, diff --git a/benchmark/src/ui.rs b/benchmark/src/ui.rs index 2b9125c4..adea0e23 100644 --- a/benchmark/src/ui.rs +++ b/benchmark/src/ui.rs @@ -17,6 +17,9 @@ use tui::widgets::{ use tui::{symbols, Terminal}; pub(crate) struct UI { + pub(crate) tokenizer_name: String, + pub(crate) sequence_length: u32, + pub(crate) decode_length: u32, pub(crate) n_run: usize, pub(crate) batch_size: Vec, pub(crate) receiver: mpsc::Receiver, @@ -117,10 +120,11 @@ impl UI { terminal.draw(|f| { // Vertical layout - let row4 = Layout::default() + let row5 = Layout::default() .direction(Direction::Vertical) .constraints( [ + Constraint::Length(1), Constraint::Length(3), Constraint::Length(3), Constraint::Length(13), @@ -134,7 +138,7 @@ impl UI { let top = Layout::default() .direction(Direction::Horizontal) .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) - .split(row4[0]); + .split(row5[2]); // Mid row horizontal layout let mid = Layout::default() @@ -148,7 +152,7 @@ impl UI { ] .as_ref(), ) - .split(row4[2]); + .split(row5[3]); // Left mid row vertical layout let prefill_text = Layout::default() @@ -166,7 +170,36 @@ impl UI { let bottom = Layout::default() .direction(Direction::Horizontal) .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) - .split(row4[3]); + .split(row5[4]); + + // Title + let title = Block::default().borders(Borders::NONE).title(format!( + "Model: {} | Sequence Length: {} | Decode Length: {}", + self.tokenizer_name, self.sequence_length, self.decode_length + )).style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White)); + f.render_widget(title, row5[0]); + + // Batch tabs + let titles = self + .batch_size + .iter() + .map(|b| { + Spans::from(vec![Span::styled( + format!("Batch: {b}"), + Style::default().fg(Color::White), + )]) + }) + .collect(); + let tabs = Tabs::new(titles) + .block(Block::default().borders(Borders::ALL).title("Tabs")) + .select(current_tab_idx) + .style(Style::default().fg(Color::LightCyan)) + .highlight_style( + Style::default() + .add_modifier(Modifier::BOLD) + .bg(Color::Black), + ); + f.render_widget(tabs, row5[1]); // Total progress bar let batch_gauge = progress_gauge( @@ -186,28 +219,6 @@ impl UI { ); f.render_widget(run_gauge, top[1]); - // Batch tabs - let titles = self - .batch_size - .iter() - .map(|b| { - Spans::from(vec![ - Span::raw(format!("Batch: {b}")), // Span::styled(first, Style::default().fg(Color::Yellow)), - // Span::styled(rest, Style::default().fg(Color::Green)), - ]) - }) - .collect(); - let tabs = Tabs::new(titles) - .block(Block::default().borders(Borders::ALL).title("Tabs")) - .select(current_tab_idx) - .style(Style::default().fg(Color::LightCyan)) - .highlight_style( - Style::default() - .add_modifier(Modifier::BOLD) - .bg(Color::Black), - ); - f.render_widget(tabs, row4[1]); - // Prefill text infos let (prefill_latency_statics, prefill_throughput_statics) = text_info( &mut prefill_latencies[current_tab_idx], @@ -384,7 +395,7 @@ fn latency_histogram<'a>( .block( Block::default() .title(format!("{name} latency histogram")) - .style(Style::default().fg(Color::Yellow).bg(Color::Reset)) + .style(Style::default().fg(Color::LightYellow).bg(Color::Reset)) .borders(Borders::ALL), ) .data(histo_data_str.as_slice())