add helper

OlivierDehaene 2023-03-30 11:56:39 +02:00
parent 271f045825
commit 17a75c8845


@@ -4,7 +4,7 @@ use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc;
 use tui::backend::Backend;
-use tui::layout::{Constraint, Direction, Layout};
+use tui::layout::{Alignment, Constraint, Direction, Layout};
 use tui::style::{Color, Modifier, Style};
 use tui::text::{Span, Spans};
 use tui::widgets::{
@@ -81,6 +81,7 @@ pub(crate) struct UI {
     completed_batch: usize,
     current_batch: usize,
     current_tab: usize,
+    zoom: bool,
     is_error: bool,
     data: Data,
     tokenizer_name: String,
@@ -114,6 +115,7 @@ impl UI {
             completed_batch,
             current_batch,
             current_tab,
+            zoom: false,
             is_error,
             data,
             tokenizer_name,
@ -130,6 +132,10 @@ impl UI {
KeyEvent { KeyEvent {
code: KeyCode::Right, code: KeyCode::Right,
.. ..
} |
KeyEvent {
code: KeyCode::Tab,
..
} => { } => {
self.current_tab = (self.current_tab + 1) % self.batch_size.len(); self.current_tab = (self.current_tab + 1) % self.batch_size.len();
} }
@@ -143,6 +149,19 @@ impl UI {
                     self.current_tab = self.batch_size.len() - 1;
                 }
             }
+            KeyEvent {
+                code: KeyCode::Char('+'),
+                ..
+            } => {
+                self.zoom = true;
+            }
+            KeyEvent {
+                code: KeyCode::Char('-'),
+                ..
+            } => {
+                self.zoom = false;
+            }
             KeyEvent {
                 code: KeyCode::Char('q'),
                 ..
@@ -255,6 +274,19 @@ impl UI {
         );
         f.render_widget(title, row5[0]);
 
+        // Helper
+        let helper = Block::default()
+            .borders(Borders::NONE)
+            .title(format!(
+                "<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom"
+            ))
+            .title_alignment(Alignment::Right)
+            .style(
+                Style::default()
+                    .fg(Color::White),
+            );
+        f.render_widget(helper, row5[0]);
+
         // Batch tabs
         let titles = self
             .batch_size
@@ -278,15 +310,25 @@ impl UI {
         f.render_widget(tabs, row5[1]);
 
         // Total progress bar
+        let color = if self.is_error {
+            Color::Red
+        } else {
+            Color::LightGreen
+        };
         let batch_gauge = progress_gauge(
             "Total Progress",
             format!("{} / {}", self.completed_batch, self.batch_size.len()),
             batch_progress,
-            Color::LightGreen,
+            color,
         );
         f.render_widget(batch_gauge, top[0]);
 
         // Batch progress Bar
+        let color = if self.is_error {
+            Color::Red
+        } else {
+            Color::LightBlue
+        };
         let run_gauge = progress_gauge(
             "Batch Progress",
             format!(
@@ -294,7 +336,7 @@ impl UI {
                 self.completed_runs[self.current_batch], self.n_run
             ),
             run_progress,
-            Color::LightBlue,
+            color,
         );
         f.render_widget(run_gauge, top[1]);
@@ -346,6 +388,7 @@ impl UI {
         let prefill_latency_throughput_chart = latency_throughput_chart(
             &self.data.prefill_batch_latency_throughput,
             &self.batch_size,
+            self.zoom,
             "Prefill",
         );
         f.render_widget(prefill_latency_throughput_chart, bottom[0]);
@@ -354,6 +397,7 @@ impl UI {
         let decode_latency_throughput_chart = latency_throughput_chart(
             &self.data.decode_batch_latency_throughput,
             &self.batch_size,
+            self.zoom,
             "Decode",
         );
         f.render_widget(decode_latency_throughput_chart, bottom[1]);
@@ -461,6 +505,7 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
 fn latency_throughput_chart<'a>(
     latency_throughput: &'a Vec<(f64, f64)>,
     batch_sizes: &'a Vec<u32>,
+    zoom: bool,
     name: &'static str,
 ) -> Chart<'a> {
     let latency_iter = latency_throughput.iter().map(|(l, _)| l);
@@ -481,11 +526,19 @@ fn latency_throughput_chart<'a>(
         .max_by(|a, b| a.total_cmp(b))
         .unwrap_or(&std::f64::NAN);
 
-    let min_x = ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0;
+    let min_x = if zoom {
+        ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0
+    } else {
+        0.0
+    };
     let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0;
     let step_x = (max_x - min_x) / 4.0;
 
-    let min_y = ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0;
+    let min_y = if zoom {
+        ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0
+    } else {
+        0.0
+    };
     let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0;
     let step_y = (max_y - min_y) / 4.0;