From 271f045825cb2b84443e3717cd057d7c5367c9de Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 30 Mar 2023 11:44:00 +0200 Subject: [PATCH] improving design --- benchmark/src/event.rs | 17 +- benchmark/src/generation.rs | 54 +-- benchmark/src/lib.rs | 89 +++-- benchmark/src/ui.rs | 649 +++++++++++++++++++----------------- 4 files changed, 441 insertions(+), 368 deletions(-) diff --git a/benchmark/src/event.rs b/benchmark/src/event.rs index e273a9a7..32d63fc8 100644 --- a/benchmark/src/event.rs +++ b/benchmark/src/event.rs @@ -1,7 +1,7 @@ /// Inspired by https://github.com/orhun/rust-tui-template use crossterm::event; -use tokio::sync::{mpsc, broadcast}; use std::time::{Duration, Instant}; +use tokio::sync::{broadcast, mpsc}; /// Events #[derive(Debug)] @@ -14,9 +14,11 @@ pub(crate) enum Event { Resize(u16, u16), } -pub(crate) async fn terminal_event_task(fps: u32, event_sender: mpsc::Sender, - mut shutdown_receiver: broadcast::Receiver<()>, - _shutdown_guard_sender: mpsc::Sender<()>, +pub(crate) async fn terminal_event_task( + fps: u32, + event_sender: mpsc::Sender, + mut shutdown_receiver: broadcast::Receiver<()>, + _shutdown_guard_sender: mpsc::Sender<()>, ) { tokio::select! { _ = event_loop(fps, event_sender) => { @@ -25,8 +27,7 @@ pub(crate) async fn terminal_event_task(fps: u32, event_sender: mpsc::Sender, -) { +async fn event_loop(fps: u32, event_sender: mpsc::Sender) { let per_frame = Duration::from_secs(1) / fps as u32; let mut last_frame = Instant::now(); loop { @@ -37,7 +38,9 @@ async fn event_loop(fps: u32, event_sender: mpsc::Sender, if event::poll(Duration::from_secs(0)).expect("no events available") { match event::read().expect("unable to read event") { event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()), - event::Event::Resize(w, h) => event_sender.send(Event::Resize(w, h)).await.unwrap_or(()), + event::Event::Resize(w, h) => { + event_sender.send(Event::Resize(w, h)).await.unwrap_or(()) + } _ => (), } } diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 024b1320..eb3b9201 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -1,5 +1,8 @@ use std::time::{Duration, Instant}; -use text_generation_client::{Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters}; +use text_generation_client::{ + Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient, + StoppingCriteriaParameters, +}; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::sync::{broadcast, mpsc}; @@ -57,26 +60,29 @@ pub(crate) async fn generation_task( } } -async fn generate_runs(tokenizer: Tokenizer, - batch_size: Vec, - sequence_length: u32, - decode_length: u32, - n_runs: usize, - warmups: usize, - mut client: ShardedClient, - run_sender: mpsc::Sender>, +async fn generate_runs( + tokenizer: Tokenizer, + batch_size: Vec, + sequence_length: u32, + decode_length: u32, + n_runs: usize, + warmups: usize, + mut client: ShardedClient, + run_sender: mpsc::Sender>, ) -> Result<(), ClientError> { let sequence = create_sequence(sequence_length, tokenizer); for b in batch_size { for _ in 0..warmups { - let (_, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?; + let (_, decode_batch) = + prefill(sequence.clone(), b, decode_length, &mut client).await?; let _ = decode(decode_batch, &mut client).await?; run_sender.send(Ok(Message::Warmup)).await.unwrap_or(()); } for _ in 0..n_runs { - let (prefill, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?; + let (prefill, decode_batch) = + prefill(sequence.clone(), b, decode_length, &mut client).await?; run_sender .send(Ok(Message::Prefill(prefill.clone()))) .await @@ -89,12 +95,15 @@ async fn generate_runs(tokenizer: Tokenizer, .await .unwrap_or(()); - run_sender.send(Ok(Message::Run(Run { - batch_size: b, - sequence_length, - prefill, - decode, - }))).await.unwrap_or(()); + run_sender + .send(Ok(Message::Run(Run { + batch_size: b, + sequence_length, + prefill, + decode, + }))) + .await + .unwrap_or(()); } run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(()); } @@ -138,8 +147,7 @@ async fn prefill( let start_time = Instant::now(); let (_, decode_batch) = client.prefill(batch.clone()).await?; let latency = start_time.elapsed(); - let throughput = batch_size as f64 - / latency.as_secs_f64(); + let throughput = batch_size as f64 / latency.as_secs_f64(); let decode_batch = decode_batch.expect("decode_batch is None. This is a bug."); @@ -151,10 +159,7 @@ async fn prefill( Ok((step, decode_batch)) } -async fn decode( - batch: Batch, - client: &mut ShardedClient, -) -> Result { +async fn decode(batch: Batch, client: &mut ShardedClient) -> Result { let mut decode_length = 0; let start_time = Instant::now(); let batch_size = batch.size; @@ -166,8 +171,7 @@ async fn decode( decode_length += 1; } let latency = start_time.elapsed(); - let throughput = (batch_size * decode_length) as f64 - / latency.as_secs_f64(); + let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64(); let step = Decode { decode_length, diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 61acc331..60de542a 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -1,15 +1,19 @@ extern crate core; +mod event; +mod generation; mod ui; mod utils; -mod generation; -mod event; +use crate::event::Event; use crate::ui::UI; +use crossterm::ExecutableCommand; +use std::io; +use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; -use text_generation_client::ShardedClient; - +use tui::backend::CrosstermBackend; +use tui::Terminal; pub async fn run( tokenizer_name: String, @@ -20,33 +24,74 @@ pub async fn run( n_runs: usize, warmups: usize, client: ShardedClient, -) -> Result<(), Box> { +) -> Result<(), crossterm::ErrorKind> { let (run_sender, run_receiver) = mpsc::channel(8); - let (shutdown_sender, shutdown_receiver) = broadcast::channel(1); + let (event_sender, mut event_receiver) = mpsc::channel(8); + let (shutdown_sender, _) = broadcast::channel(1); let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1); - tokio::spawn( - generation::generation_task(tokenizer, batch_size.clone(), sequence_length, decode_length, n_runs, warmups, client, run_sender, shutdown_receiver, shutdown_guard_sender.clone()), + tokio::spawn(generation::generation_task( + tokenizer, + batch_size.clone(), + sequence_length, + decode_length, + n_runs, + warmups, + client, + run_sender, + shutdown_sender.subscribe(), + shutdown_guard_sender.clone(), + )); + + tokio::spawn(event::terminal_event_task( + 250, + event_sender, + shutdown_sender.subscribe(), + shutdown_guard_sender.clone(), + )); + + drop(shutdown_guard_sender); + + let mut ui = UI::new( + run_receiver, + tokenizer_name, + sequence_length, + decode_length, + n_runs, + batch_size, ); - tokio::spawn( - UI { - tokenizer_name, - decode_length, - sequence_length, - n_run: n_runs, - batch_size: batch_size, - receiver: run_receiver, - shutdown_sender, - _shutdown_guard_sender: shutdown_guard_sender.clone() + crossterm::terminal::enable_raw_mode()?; + io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; + io::stdout().execute(crossterm::cursor::Hide)?; + + let mut terminal = { + let backend = CrosstermBackend::new(io::stdout()); + Terminal::new(backend)? + }; + + while ui.running { + terminal.draw(|frame| ui.render(frame))?; + + match event_receiver.recv().await { + None => break, + Some(event) => match event { + Event::Tick => ui.tick(), + Event::Key(key_event) => ui.handle_key_event(key_event), + _ => {} + }, } - .draw(), - ); - - drop (shutdown_guard_sender); + } + // Ask tasks to shutdown + let _ = shutdown_sender.send(()); // Wait for tasks to shutdown let _ = shutdown_guard_receiver.recv().await; + // Revert terminal to original view + io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; + crossterm::terminal::disable_raw_mode()?; + io::stdout().execute(crossterm::cursor::Show)?; + Ok(()) } diff --git a/benchmark/src/ui.rs b/benchmark/src/ui.rs index ed2875ab..7b251f0f 100644 --- a/benchmark/src/ui.rs +++ b/benchmark/src/ui.rs @@ -1,341 +1,362 @@ +use crate::generation::{Decode, Message, Prefill}; /// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs -use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers}; -use crossterm::{event, ExecutableCommand}; -use std::io; -use std::time::{Duration, Instant}; -use tokio::sync::mpsc::error::TryRecvError; -use tokio::sync::{broadcast, mpsc}; -use tokio::time::sleep; -use tui::backend::CrosstermBackend; +use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use text_generation_client::ClientError; +use tokio::sync::mpsc; +use tui::backend::Backend; use tui::layout::{Constraint, Direction, Layout}; use tui::style::{Color, Modifier, Style}; use tui::text::{Span, Spans}; use tui::widgets::{ Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs, }; -use tui::{symbols, Terminal}; -use text_generation_client::ClientError; -use crate::generation::Message; +use tui::{symbols, Frame}; + +struct Data { + prefill_latencies: Vec>, + prefill_throughputs: Vec>, + decode_latencies: Vec>, + decode_throughputs: Vec>, + prefill_batch_latency_throughput: Vec<(f64, f64)>, + decode_batch_latency_throughput: Vec<(f64, f64)>, +} + +impl Data { + fn new(n_run: usize, n_batch: usize) -> Self { + let prefill_latencies: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + let prefill_throughputs: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + + let decode_latencies: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + let decode_throughputs: Vec> = + (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect(); + + let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch); + + let decode_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch); + + Self { + prefill_latencies, + prefill_throughputs, + decode_latencies, + decode_throughputs, + prefill_batch_latency_throughput, + decode_batch_latency_throughput, + } + } + + fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) { + let latency = prefill.latency.as_millis() as f64; + self.prefill_latencies[batch_idx].push(latency); + self.prefill_throughputs[batch_idx].push(prefill.throughput); + } + + fn push_decode(&mut self, prefill: Decode, batch_idx: usize) { + let latency = prefill.latency.as_millis() as f64; + self.decode_latencies[batch_idx].push(latency); + self.decode_throughputs[batch_idx].push(prefill.throughput); + } + + fn end_batch(&mut self, batch_idx: usize) { + self.prefill_batch_latency_throughput.push(( + self.prefill_latencies[batch_idx].iter().sum::() + / self.prefill_latencies[batch_idx].len() as f64, + self.prefill_throughputs[batch_idx].iter().sum::() + / self.prefill_throughputs[batch_idx].len() as f64, + )); + self.decode_batch_latency_throughput.push(( + self.decode_latencies[batch_idx].iter().sum::() + / self.decode_latencies[batch_idx].len() as f64, + self.decode_throughputs[batch_idx].iter().sum::() + / self.decode_throughputs[batch_idx].len() as f64, + )); + } +} pub(crate) struct UI { - pub(crate) tokenizer_name: String, - pub(crate) sequence_length: u32, - pub(crate) decode_length: u32, - pub(crate) n_run: usize, - pub(crate) batch_size: Vec, - pub(crate) receiver: mpsc::Receiver>, - pub(crate) shutdown_sender: broadcast::Sender<()>, - pub(crate) _shutdown_guard_sender: mpsc::Sender<()>, + pub(crate) running: bool, + completed_runs: Vec, + completed_batch: usize, + current_batch: usize, + current_tab: usize, + is_error: bool, + data: Data, + tokenizer_name: String, + sequence_length: u32, + decode_length: u32, + n_run: usize, + batch_size: Vec, + receiver: mpsc::Receiver>, } impl UI { - pub async fn draw(mut self) -> Result<(), crossterm::ErrorKind> { - crossterm::terminal::enable_raw_mode()?; - io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; - io::stdout().execute(crossterm::cursor::Hide)?; + pub(crate) fn new( + receiver: mpsc::Receiver>, + tokenizer_name: String, + sequence_length: u32, + decode_length: u32, + n_run: usize, + batch_size: Vec, + ) -> Self { + let data = Data::new(n_run, batch_size.len()); + let current_tab = 0; - let mut current_tab_idx = 0; + let completed_runs: Vec = (0..batch_size.len()).map(|_| 0).collect(); + let completed_batch = 0; + let current_batch = 0; + let is_error = false; - let mut prefill_latencies: Vec> = (0..self.batch_size.len()) - .map(|_| Vec::with_capacity(self.n_run)) - .collect(); - let mut prefill_throughputs: Vec> = (0..self.batch_size.len()) - .map(|_| Vec::with_capacity(self.n_run)) - .collect(); + Self { + running: true, + completed_runs, + completed_batch, + current_batch, + current_tab, + is_error, + data, + tokenizer_name, + sequence_length, + decode_length, + n_run, + batch_size, + receiver, + } + } - let mut decode_latencies: Vec> = (0..self.batch_size.len()) - .map(|_| Vec::with_capacity(self.n_run)) - .collect(); - let mut decode_throughputs: Vec> = (0..self.batch_size.len()) - .map(|_| Vec::with_capacity(self.n_run)) - .collect(); - - let mut prefill_batch_latency_throughput: Vec<(f64, f64)> = - Vec::with_capacity(self.batch_size.len()); - - let mut decode_batch_latency_throughput: Vec<(f64, f64)> = - Vec::with_capacity(self.batch_size.len()); - - let mut completed_runs: Vec = (0..self.batch_size.len()).map(|_| 0).collect(); - let mut completed_batch = 0; - let mut current_batch_idx = 0; - let mut is_error = false; - - let mut terminal = { - let backend = CrosstermBackend::new(io::stdout()); - Terminal::new(backend)? - }; - - 'outer: loop { - let frame_start = Instant::now(); - loop { - match self.receiver.try_recv() { - Ok(message) => match message { - Ok(message) => { - match message { - Message::Prefill(step) => { - let latency = step.latency.as_millis() as f64; - prefill_latencies[current_batch_idx].push(latency); - prefill_throughputs[current_batch_idx].push(step.throughput); - } - Message::Decode(step) => { - let latency = step.latency.as_millis() as f64; - decode_latencies[current_batch_idx].push(latency); - decode_throughputs[current_batch_idx].push(step.throughput); - } - Message::Run(_) => { - completed_runs[current_batch_idx] += 1; - } - Message::EndBatch => { - prefill_batch_latency_throughput.push(( - prefill_latencies[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - prefill_throughputs[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - )); - decode_batch_latency_throughput.push(( - decode_latencies[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - decode_throughputs[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - )); - - completed_batch += 1; - if current_batch_idx < self.batch_size.len() - 1 { - current_batch_idx += 1; - } - } - Message::Warmup => {} - } - } - Err(_) => is_error = true - }, - Err(TryRecvError::Empty) => { - break; - } - Err(TryRecvError::Disconnected) => { - break; - } - } + pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) { + match key_event { + KeyEvent { + code: KeyCode::Right, + .. + } => { + self.current_tab = (self.current_tab + 1) % self.batch_size.len(); } - - let batch_progress = - (completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0); - let run_progress = - (completed_runs[current_batch_idx] as f64 / self.n_run as f64).clamp(0.0, 1.0); - - terminal.draw(|f| { - // Vertical layout - let row5 = Layout::default() - .direction(Direction::Vertical) - .constraints( - [ - Constraint::Length(1), - Constraint::Length(3), - Constraint::Length(3), - Constraint::Length(13), - Constraint::Min(10), - ] - .as_ref(), - ) - .split(f.size()); - - // Top row horizontal layout - let top = Layout::default() - .direction(Direction::Horizontal) - .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) - .split(row5[2]); - - // Mid row horizontal layout - let mid = Layout::default() - .direction(Direction::Horizontal) - .constraints( - [ - Constraint::Percentage(20), - Constraint::Percentage(30), - Constraint::Percentage(20), - Constraint::Percentage(30), - ] - .as_ref(), - ) - .split(row5[3]); - - // Left mid row vertical layout - let prefill_text = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref()) - .split(mid[0]); - - // Right mid row vertical layout - let decode_text = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref()) - .split(mid[2]); - - // Bottom row horizontal layout - let bottom = Layout::default() - .direction(Direction::Horizontal) - .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) - .split(row5[4]); - - // Title - let title = Block::default().borders(Borders::NONE).title(format!( - "Model: {} | Sequence Length: {} | Decode Length: {}", - self.tokenizer_name, self.sequence_length, self.decode_length - )).style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White)); - f.render_widget(title, row5[0]); - - // Batch tabs - let titles = self - .batch_size - .iter() - .map(|b| { - Spans::from(vec![Span::styled( - format!("Batch: {b}"), - Style::default().fg(Color::White), - )]) - }) - .collect(); - let tabs = Tabs::new(titles) - .block(Block::default().borders(Borders::ALL).title("Tabs")) - .select(current_tab_idx) - .style(Style::default().fg(Color::LightCyan)) - .highlight_style( - Style::default() - .add_modifier(Modifier::BOLD) - .bg(Color::Black), - ); - f.render_widget(tabs, row5[1]); - - // Total progress bar - let batch_gauge = progress_gauge( - "Total Progress", - format!("{} / {}", completed_batch, self.batch_size.len()), - batch_progress, - Color::LightGreen, - ); - f.render_widget(batch_gauge, top[0]); - - // Batch progress Bar - let run_gauge = progress_gauge( - "Batch Progress", - format!("{} / {}", completed_runs[current_batch_idx], self.n_run), - run_progress, - Color::LightBlue, - ); - f.render_widget(run_gauge, top[1]); - - // Prefill text infos - let (prefill_latency_statics, prefill_throughput_statics) = text_info( - &mut prefill_latencies[current_tab_idx], - &prefill_throughputs[current_tab_idx], - "Prefill", - ); - f.render_widget(prefill_latency_statics, prefill_text[0]); - f.render_widget(prefill_throughput_statics, prefill_text[1]); - - // Prefill latency histogram - let histo_width = 7; - let bins = if mid[1].width < 2 { - 0 + KeyEvent { + code: KeyCode::Left, + .. + } => { + if self.current_tab > 0 { + self.current_tab -= 1; } else { - (mid[1].width as usize - 2) / (histo_width + 1) - } - .max(2); - - let histo_data = latency_histogram_data(&prefill_latencies[current_tab_idx], bins); - let histo_data_str: Vec<(&str, u64)> = - histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect(); - let prefill_histogram = - latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16); - f.render_widget(prefill_histogram, mid[1]); - - // Decode text info - let (decode_latency_statics, decode_throughput_statics) = text_info( - &mut decode_latencies[current_tab_idx], - &decode_throughputs[current_tab_idx], - "Decode", - ); - f.render_widget(decode_latency_statics, decode_text[0]); - f.render_widget(decode_throughput_statics, decode_text[1]); - - // Decode latency histogram - let histo_data = latency_histogram_data(&decode_latencies[current_tab_idx], bins); - let histo_data_str: Vec<(&str, u64)> = - histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect(); - let decode_histogram = - latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16); - f.render_widget(decode_histogram, mid[3]); - - // Prefill latency/throughput chart - let prefill_latency_throughput_chart = latency_throughput_chart( - &prefill_batch_latency_throughput, - &self.batch_size, - "Prefill", - ); - f.render_widget(prefill_latency_throughput_chart, bottom[0]); - - // Decode latency/throughput chart - let decode_latency_throughput_chart = latency_throughput_chart( - &decode_batch_latency_throughput, - &self.batch_size, - "Decode", - ); - f.render_widget(decode_latency_throughput_chart, bottom[1]); - })?; - - // Quit on q or CTRL+c - - while event::poll(Duration::from_millis(100))? { - if let Event::Key(key) = event::read()? { - match key { - KeyEvent { - code: KeyCode::Right, - .. - } => { - current_tab_idx = (current_tab_idx + 1) % self.batch_size.len(); - } - KeyEvent { - code: KeyCode::Left, - .. - } => { - if current_tab_idx > 0 { - current_tab_idx -= 1; - } else { - current_tab_idx = self.batch_size.len() - 1; - } - } - KeyEvent { - code: KeyCode::Char('q'), - .. - } - | KeyEvent { - code: KeyCode::Char('c'), - modifiers: KeyModifiers::CONTROL, - .. - } => { - break 'outer; - } - _ => (), - } + self.current_tab = self.batch_size.len() - 1; } } + KeyEvent { + code: KeyCode::Char('q'), + .. + } + | KeyEvent { + code: KeyCode::Char('c'), + modifiers: KeyModifiers::CONTROL, + .. + } => { + self.running = false; + } + _ => (), + } + } - // Frame budget - let per_frame = Duration::from_secs(1) / 30 as u32; - let elapsed = frame_start.elapsed(); - if per_frame > elapsed { - sleep(per_frame - elapsed).await; + pub(crate) fn tick(&mut self) { + while let Ok(message) = self.receiver.try_recv() { + match message { + Ok(message) => match message { + Message::Prefill(step) => self.data.push_prefill(step, self.current_batch), + Message::Decode(step) => self.data.push_decode(step, self.current_batch), + Message::Run(_) => { + self.completed_runs[self.current_batch] += 1; + } + Message::EndBatch => { + self.data.end_batch(self.current_batch); + + self.completed_batch += 1; + if self.current_batch < self.batch_size.len() - 1 { + self.current_batch += 1; + } + } + Message::Warmup => {} + }, + Err(_) => self.is_error = true, } } + } - // Revert terminal to original view - io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; - crossterm::terminal::disable_raw_mode()?; - io::stdout().execute(crossterm::cursor::Show)?; + pub fn render(&mut self, f: &mut Frame<'_, B>) { + let batch_progress = + (self.completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0); + let run_progress = + (self.completed_runs[self.current_batch] as f64 / self.n_run as f64).clamp(0.0, 1.0); - let _ = self.shutdown_sender.send(()); - Ok(()) + // Vertical layout + let row5 = Layout::default() + .direction(Direction::Vertical) + .constraints( + [ + Constraint::Length(1), + Constraint::Length(3), + Constraint::Length(3), + Constraint::Length(13), + Constraint::Min(10), + ] + .as_ref(), + ) + .split(f.size()); + + // Top row horizontal layout + let top = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) + .split(row5[2]); + + // Mid row horizontal layout + let mid = Layout::default() + .direction(Direction::Horizontal) + .constraints( + [ + Constraint::Percentage(20), + Constraint::Percentage(30), + Constraint::Percentage(20), + Constraint::Percentage(30), + ] + .as_ref(), + ) + .split(row5[3]); + + // Left mid row vertical layout + let prefill_text = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref()) + .split(mid[0]); + + // Right mid row vertical layout + let decode_text = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref()) + .split(mid[2]); + + // Bottom row horizontal layout + let bottom = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()) + .split(row5[4]); + + // Title + let title = Block::default() + .borders(Borders::NONE) + .title(format!( + "Model: {} | Sequence Length: {} | Decode Length: {}", + self.tokenizer_name, self.sequence_length, self.decode_length + )) + .style( + Style::default() + .add_modifier(Modifier::BOLD) + .fg(Color::White), + ); + f.render_widget(title, row5[0]); + + // Batch tabs + let titles = self + .batch_size + .iter() + .map(|b| { + Spans::from(vec![Span::styled( + format!("Batch: {b}"), + Style::default().fg(Color::White), + )]) + }) + .collect(); + let tabs = Tabs::new(titles) + .block(Block::default().borders(Borders::ALL).title("Tabs")) + .select(self.current_tab) + .style(Style::default().fg(Color::LightCyan)) + .highlight_style( + Style::default() + .add_modifier(Modifier::BOLD) + .bg(Color::Black), + ); + f.render_widget(tabs, row5[1]); + + // Total progress bar + let batch_gauge = progress_gauge( + "Total Progress", + format!("{} / {}", self.completed_batch, self.batch_size.len()), + batch_progress, + Color::LightGreen, + ); + f.render_widget(batch_gauge, top[0]); + + // Batch progress Bar + let run_gauge = progress_gauge( + "Batch Progress", + format!( + "{} / {}", + self.completed_runs[self.current_batch], self.n_run + ), + run_progress, + Color::LightBlue, + ); + f.render_widget(run_gauge, top[1]); + + // Prefill text infos + let (prefill_latency_statics, prefill_throughput_statics) = text_info( + &mut self.data.prefill_latencies[self.current_tab], + &self.data.prefill_throughputs[self.current_tab], + "Prefill", + ); + f.render_widget(prefill_latency_statics, prefill_text[0]); + f.render_widget(prefill_throughput_statics, prefill_text[1]); + + // Prefill latency histogram + let histo_width = 7; + let bins = if mid[1].width < 2 { + 0 + } else { + (mid[1].width as usize - 2) / (histo_width + 1) + } + .max(2); + + let histo_data = + latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins); + let histo_data_str: Vec<(&str, u64)> = + histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect(); + let prefill_histogram = + latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16); + f.render_widget(prefill_histogram, mid[1]); + + // Decode text info + let (decode_latency_statics, decode_throughput_statics) = text_info( + &mut self.data.decode_latencies[self.current_tab], + &self.data.decode_throughputs[self.current_tab], + "Decode", + ); + f.render_widget(decode_latency_statics, decode_text[0]); + f.render_widget(decode_throughput_statics, decode_text[1]); + + // Decode latency histogram + let histo_data = + latency_histogram_data(&self.data.decode_latencies[self.current_tab], bins); + let histo_data_str: Vec<(&str, u64)> = + histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect(); + let decode_histogram = + latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16); + f.render_widget(decode_histogram, mid[3]); + + // Prefill latency/throughput chart + let prefill_latency_throughput_chart = latency_throughput_chart( + &self.data.prefill_batch_latency_throughput, + &self.batch_size, + "Prefill", + ); + f.render_widget(prefill_latency_throughput_chart, bottom[0]); + + // Decode latency/throughput chart + let decode_latency_throughput_chart = latency_throughput_chart( + &self.data.decode_batch_latency_throughput, + &self.batch_size, + "Decode", + ); + f.render_widget(decode_latency_throughput_chart, bottom[1]); } }