improvements

OlivierDehaene 2023-03-29 14:01:23 +02:00
parent 383619bd7f
commit 1c5d526943
3 changed files with 62 additions and 46 deletions

View File

@@ -44,6 +44,7 @@ pub(crate) enum Message {
}
pub async fn run(
tokenizer_name: String,
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
@@ -57,6 +58,9 @@ pub async fn run(
tokio::spawn(
UI {
tokenizer_name,
decode_length,
sequence_length,
n_run: n_runs,
batch_size: batch_size.clone(),
receiver: ui_receiver,
@@ -68,6 +72,7 @@ pub async fn run(
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
let sequence = create_sequence(sequence_length, tokenizer);
for b in batch_size {
for _ in 0..warmups {
let (_, decode_batch) = tokio::select! {
res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
@@ -83,7 +88,6 @@ pub async fn run(
};
}
for b in batch_size {
for _ in 0..n_runs {
let (prefill, decode_batch) = tokio::select! {
res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,

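The loop above now separates warmup requests from the measured runs for each batch size, presumably so the first measured timings are not skewed by cold starts. Below is a rough, self-contained sketch of that warmup-then-measure shape, ignoring the async client and the tokio::select! shutdown handling in the real loop; the `infer` closure and the sleep are placeholders, not the crate's client API:

```rust
use std::time::{Duration, Instant};

// Run `warmups` discarded requests, then `n_runs` timed requests per batch size.
fn benchmark<F: FnMut(u32)>(
    batch_sizes: &[u32],
    warmups: usize,
    n_runs: usize,
    mut infer: F,
) -> Vec<(u32, Duration)> {
    let mut results = Vec::with_capacity(batch_sizes.len() * n_runs);
    for &b in batch_sizes {
        // Warmup requests: exercise the server, discard the timings.
        for _ in 0..warmups {
            infer(b);
        }
        // Measured runs: keep one latency sample per run.
        for _ in 0..n_runs {
            let start = Instant::now();
            infer(b);
            results.push((b, start.elapsed()));
        }
    }
    results
}

fn main() {
    let latencies = benchmark(&[1, 2, 4], 2, 10, |batch| {
        // Placeholder for a prefill + decode round trip at this batch size.
        std::thread::sleep(Duration::from_millis(batch as u64));
    });
    println!("collected {} samples", latencies.len());
}
```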
View File

@@ -15,13 +15,13 @@ struct Args {
tokenizer_name: String,
#[clap(default_value = "1", long, env)]
batch_size: Vec<u32>,
#[clap(default_value = "12", long, env)]
sequence_length: u32,
#[clap(default_value = "10", long, env)]
sequence_length: u32,
#[clap(default_value = "64", long, env)]
decode_length: u32,
#[clap(default_value = "10", long, env)]
runs: usize,
#[clap(default_value = "0", long, env)]
#[clap(default_value = "2", long, env)]
warmups: usize,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
@@ -74,12 +74,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache()
.clear_cache(None)
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
text_generation_benchmark::run(
tokenizer_name,
tokenizer,
batch_size,
sequence_length,

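This file adjusts the CLI: the clap defaults for sequence_length and warmups change, clear_cache is now called with an explicit `None`, and the tokenizer is forwarded into `run` alongside its name. A minimal standalone sketch of the clap derive pattern those Args use (field names and defaults copied from the diff; the rest is illustrative):

```rust
// Sketch of the Args pattern above (clap 3-style attributes, with the
// `derive` and `env` features enabled); only a subset of fields is shown.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Overridable on the CLI (`--sequence-length 20`) or via env (`SEQUENCE_LENGTH=20`).
    #[clap(default_value = "10", long, env)]
    sequence_length: u32,
    #[clap(default_value = "64", long, env)]
    decode_length: u32,
    #[clap(default_value = "10", long, env)]
    runs: usize,
    /// A couple of discarded warmup requests before the measured runs.
    #[clap(default_value = "2", long, env)]
    warmups: usize,
}

fn main() {
    let args = Args::parse();
    println!("{args:?}");
}
```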
View File

@@ -17,6 +17,9 @@ use tui::widgets::{
use tui::{symbols, Terminal};
pub(crate) struct UI {
pub(crate) tokenizer_name: String,
pub(crate) sequence_length: u32,
pub(crate) decode_length: u32,
pub(crate) n_run: usize,
pub(crate) batch_size: Vec<u32>,
pub(crate) receiver: mpsc::Receiver<Message>,
@@ -117,10 +120,11 @@ impl UI {
terminal.draw(|f| {
// Vertical layout
let row4 = Layout::default()
let row5 = Layout::default()
.direction(Direction::Vertical)
.constraints(
[
Constraint::Length(1),
Constraint::Length(3),
Constraint::Length(3),
Constraint::Length(13),
@@ -134,7 +138,7 @@ impl UI {
let top = Layout::default()
.direction(Direction::Horizontal)
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
.split(row4[0]);
.split(row5[2]);
// Mid row horizontal layout
let mid = Layout::default()
@@ -148,7 +152,7 @@ impl UI {
]
.as_ref(),
)
.split(row4[2]);
.split(row5[3]);
// Left mid row vertical layout
let prefill_text = Layout::default()
@@ -166,7 +170,36 @@ impl UI {
let bottom = Layout::default()
.direction(Direction::Horizontal)
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
.split(row4[3]);
.split(row5[4]);
// Title
let title = Block::default().borders(Borders::NONE).title(format!(
"Model: {} | Sequence Length: {} | Decode Length: {}",
self.tokenizer_name, self.sequence_length, self.decode_length
)).style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White));
f.render_widget(title, row5[0]);
// Batch tabs
let titles = self
.batch_size
.iter()
.map(|b| {
Spans::from(vec![Span::styled(
format!("Batch: {b}"),
Style::default().fg(Color::White),
)])
})
.collect();
let tabs = Tabs::new(titles)
.block(Block::default().borders(Borders::ALL).title("Tabs"))
.select(current_tab_idx)
.style(Style::default().fg(Color::LightCyan))
.highlight_style(
Style::default()
.add_modifier(Modifier::BOLD)
.bg(Color::Black),
);
f.render_widget(tabs, row5[1]);
// Total progress bar
let batch_gauge = progress_gauge(
@@ -186,28 +219,6 @@ impl UI {
);
f.render_widget(run_gauge, top[1]);
// Batch tabs
let titles = self
.batch_size
.iter()
.map(|b| {
Spans::from(vec![
Span::raw(format!("Batch: {b}")), // Span::styled(first, Style::default().fg(Color::Yellow)),
// Span::styled(rest, Style::default().fg(Color::Green)),
])
})
.collect();
let tabs = Tabs::new(titles)
.block(Block::default().borders(Borders::ALL).title("Tabs"))
.select(current_tab_idx)
.style(Style::default().fg(Color::LightCyan))
.highlight_style(
Style::default()
.add_modifier(Modifier::BOLD)
.bg(Color::Black),
);
f.render_widget(tabs, row4[1]);
// Prefill text infos
let (prefill_latency_statics, prefill_throughput_statics) = text_info(
&mut prefill_latencies[current_tab_idx],
@@ -384,7 +395,7 @@ fn latency_histogram<'a>(
.block(
Block::default()
.title(format!("{name} latency histogram"))
.style(Style::default().fg(Color::Yellow).bg(Color::Reset))
.style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
.borders(Borders::ALL),
)
.data(histo_data_str.as_slice())
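In the UI, the vertical layout gains a row (row4 becomes row5): a one-line bold title with the model name, sequence length, and decode length sits at the top, and the batch tabs move from the middle of the screen up to the second row, with the remaining panes shifted down by one index. A rough standalone sketch of that title-plus-tabs arrangement with the tui crate follows; the model name, batch sizes, and the crossterm backend setup are placeholders, and a real TUI would also enter raw mode and the alternate screen:

```rust
use std::io;
use tui::backend::CrosstermBackend;
use tui::layout::{Constraint, Direction, Layout};
use tui::style::{Color, Modifier, Style};
use tui::text::{Span, Spans};
use tui::widgets::{Block, Borders, Tabs};
use tui::Terminal;

fn main() -> Result<(), io::Error> {
    let backend = CrosstermBackend::new(io::stdout());
    let mut terminal = Terminal::new(backend)?;
    terminal.draw(|f| {
        // Vertical layout: a one-line title, a three-line tab bar, then the rest.
        let rows = Layout::default()
            .direction(Direction::Vertical)
            .constraints(
                [
                    Constraint::Length(1),
                    Constraint::Length(3),
                    Constraint::Min(0),
                ]
                .as_ref(),
            )
            .split(f.size());

        // Bold title line, mirroring the format string in the diff.
        let title = Block::default()
            .borders(Borders::NONE)
            .title("Model: my-model | Sequence Length: 10 | Decode Length: 64")
            .style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White));
        f.render_widget(title, rows[0]);

        // One tab per benchmarked batch size.
        let titles: Vec<Spans> = [1u32, 2, 4]
            .iter()
            .map(|b| {
                Spans::from(Span::styled(
                    format!("Batch: {b}"),
                    Style::default().fg(Color::White),
                ))
            })
            .collect();
        let tabs = Tabs::new(titles)
            .block(Block::default().borders(Borders::ALL).title("Tabs"))
            .select(0)
            .style(Style::default().fg(Color::LightCyan))
            .highlight_style(Style::default().add_modifier(Modifier::BOLD).bg(Color::Black));
        f.render_widget(tabs, rows[1]);
    })?;
    Ok(())
}
```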