improvements

OlivierDehaene 2023-03-29 14:01:23 +02:00
parent 383619bd7f
commit 1c5d526943
3 changed files with 62 additions and 46 deletions

View File

@@ -44,6 +44,7 @@ pub(crate) enum Message {
 }

 pub async fn run(
+    tokenizer_name: String,
     tokenizer: Tokenizer,
     batch_size: Vec<u32>,
     sequence_length: u32,
@@ -57,6 +58,9 @@ pub async fn run(
     tokio::spawn(
         UI {
+            tokenizer_name,
+            decode_length,
+            sequence_length,
             n_run: n_runs,
             batch_size: batch_size.clone(),
             receiver: ui_receiver,
@@ -68,22 +72,22 @@ pub async fn run(
     let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
     let sequence = create_sequence(sequence_length, tokenizer);

-    for _ in 0..warmups {
-        let (_, decode_batch) = tokio::select! {
-            res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
-            _ = shutdown_receiver.recv() => {
-                return Ok(());
-            }
-        };
-        let _ = tokio::select! {
-            res = run_decode(decode_batch, sequence_length, &mut client) => res?,
-            _ = shutdown_receiver.recv() => {
-                return Ok(());
-            }
-        };
-    }
     for b in batch_size {
+        for _ in 0..warmups {
+            let (_, decode_batch) = tokio::select! {
+                res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
+                _ = shutdown_receiver.recv() => {
+                    return Ok(());
+                }
+            };
+            let _ = tokio::select! {
+                res = run_decode(decode_batch, sequence_length, &mut client) => res?,
+                _ = shutdown_receiver.recv() => {
+                    return Ok(());
+                }
+            };
+        }
         for _ in 0..n_runs {
             let (prefill, decode_batch) = tokio::select! {
                 res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
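Note on the hunk above: every prefill/decode call is raced against a shutdown channel with tokio::select!, and warmup iterations run before the measured runs for each batch size. A minimal, self-contained sketch of that pattern (bench_step, run_all and the sleep-based timing are illustrative placeholders, not the crate's API):

use std::time::{Duration, Instant};
use tokio::sync::mpsc;

// Stand-in for one prefill + decode round; returns the elapsed time in microseconds.
async fn bench_step(batch_size: u32) -> Result<u128, ()> {
    let start = Instant::now();
    // Pretend the work scales with the batch size.
    tokio::time::sleep(Duration::from_millis(batch_size as u64)).await;
    Ok(start.elapsed().as_micros())
}

// Warmups run first and are discarded; measured runs are recorded.
// Every await is raced against the shutdown channel, mirroring the hunk above.
async fn run_all(
    batch_sizes: Vec<u32>,
    warmups: usize,
    n_runs: usize,
    mut shutdown: mpsc::Receiver<()>,
) -> Result<Vec<u128>, ()> {
    let mut latencies = Vec::new();
    for b in batch_sizes {
        for _ in 0..warmups {
            tokio::select! {
                res = bench_step(b) => { res?; }
                _ = shutdown.recv() => return Ok(latencies),
            }
        }
        for _ in 0..n_runs {
            tokio::select! {
                res = bench_step(b) => latencies.push(res?),
                _ = shutdown.recv() => return Ok(latencies),
            }
        }
    }
    Ok(latencies)
}

#[tokio::main]
async fn main() {
    // Keep the sender alive so recv() stays pending and the benchmark runs to completion.
    let (_shutdown_sender, shutdown_receiver) = mpsc::channel::<()>(1);
    let latencies = run_all(vec![1, 2], 2, 3, shutdown_receiver).await.unwrap();
    println!("recorded {} measurements", latencies.len());
}

Running the warmup passes inside the per-batch-size loop means each batch size gets its own unrecorded iterations before any latency is kept.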

View File

@@ -15,13 +15,13 @@ struct Args {
     tokenizer_name: String,
     #[clap(default_value = "1", long, env)]
     batch_size: Vec<u32>,
-    #[clap(default_value = "12", long, env)]
-    sequence_length: u32,
     #[clap(default_value = "10", long, env)]
+    sequence_length: u32,
+    #[clap(default_value = "64", long, env)]
     decode_length: u32,
     #[clap(default_value = "10", long, env)]
     runs: usize,
-    #[clap(default_value = "0", long, env)]
+    #[clap(default_value = "2", long, env)]
     warmups: usize,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
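For reference, the #[clap(..., long, env)] attributes above mean each option can be set by flag, by environment variable, or fall back to its default_value. A standalone sketch of that behaviour (the struct is trimmed to two fields for illustration, not the benchmark's full argument list):

use clap::Parser;

// Trimmed-down Args; only the attribute mechanics matter here.
#[derive(Parser, Debug)]
struct Args {
    /// Batch sizes to benchmark (pass the flag several times for several values)
    #[clap(default_value = "1", long, env)]
    batch_size: Vec<u32>,
    /// Warmup iterations executed before any measurement is recorded
    #[clap(default_value = "2", long, env)]
    warmups: usize,
}

fn main() {
    // e.g. `BATCH_SIZE=4 bench --warmups 5`, or no arguments at all for the defaults
    let args = Args::parse();
    println!("{args:?}");
}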
@@ -74,12 +74,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .expect("Could not connect to server");

     // Clear the cache; useful if the webserver rebooted
     sharded_client
-        .clear_cache()
+        .clear_cache(None)
         .await
         .expect("Unable to clear cache");
     tracing::info!("Connected");
     text_generation_benchmark::run(
+        tokenizer_name,
         tokenizer,
         batch_size,
         sequence_length,

View File

@@ -17,6 +17,9 @@ use tui::widgets::{
 use tui::{symbols, Terminal};

 pub(crate) struct UI {
+    pub(crate) tokenizer_name: String,
+    pub(crate) sequence_length: u32,
+    pub(crate) decode_length: u32,
     pub(crate) n_run: usize,
     pub(crate) batch_size: Vec<u32>,
     pub(crate) receiver: mpsc::Receiver<Message>,
@@ -117,10 +120,11 @@ impl UI {
         terminal.draw(|f| {
             // Vertical layout
-            let row4 = Layout::default()
+            let row5 = Layout::default()
                 .direction(Direction::Vertical)
                 .constraints(
                     [
+                        Constraint::Length(1),
                         Constraint::Length(3),
                         Constraint::Length(3),
                         Constraint::Length(13),
@@ -134,7 +138,7 @@ impl UI {
             let top = Layout::default()
                 .direction(Direction::Horizontal)
                 .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
-                .split(row4[0]);
+                .split(row5[2]);

             // Mid row horizontal layout
             let mid = Layout::default()
@@ -148,7 +152,7 @@ impl UI {
                    ]
                    .as_ref(),
                )
-               .split(row4[2]);
+               .split(row5[3]);

             // Left mid row vertical layout
             let prefill_text = Layout::default()
@@ -166,7 +170,36 @@ impl UI {
             let bottom = Layout::default()
                 .direction(Direction::Horizontal)
                 .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
-                .split(row4[3]);
+                .split(row5[4]);
+
+            // Title
+            let title = Block::default().borders(Borders::NONE).title(format!(
+                "Model: {} | Sequence Length: {} | Decode Length: {}",
+                self.tokenizer_name, self.sequence_length, self.decode_length
+            )).style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White));
+            f.render_widget(title, row5[0]);
+
+            // Batch tabs
+            let titles = self
+                .batch_size
+                .iter()
+                .map(|b| {
+                    Spans::from(vec![Span::styled(
+                        format!("Batch: {b}"),
+                        Style::default().fg(Color::White),
+                    )])
+                })
+                .collect();
+            let tabs = Tabs::new(titles)
+                .block(Block::default().borders(Borders::ALL).title("Tabs"))
+                .select(current_tab_idx)
+                .style(Style::default().fg(Color::LightCyan))
+                .highlight_style(
+                    Style::default()
+                        .add_modifier(Modifier::BOLD)
+                        .bg(Color::Black),
+                );
+            f.render_widget(tabs, row5[1]);

             // Total progress bar
             let batch_gauge = progress_gauge(
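The layout change above reserves one extra fixed-height row for the new title, so the title is rendered into row5[0], the batch tabs into row5[1], and the gauge/statistics/histogram rows shift to indices 2 through 4. A minimal tui sketch of the same title-plus-tabs arrangement (the strings, batch sizes and three-row layout are illustrative, not the benchmark's exact UI):

use std::io;

use tui::backend::CrosstermBackend;
use tui::layout::{Constraint, Direction, Layout};
use tui::style::{Color, Modifier, Style};
use tui::text::{Span, Spans};
use tui::widgets::{Block, Borders, Tabs};
use tui::Terminal;

fn main() -> Result<(), io::Error> {
    let mut terminal = Terminal::new(CrosstermBackend::new(io::stdout()))?;

    terminal.draw(|f| {
        // One fixed-height row for the title, one for the tabs, the rest for content.
        let rows = Layout::default()
            .direction(Direction::Vertical)
            .constraints(
                [
                    Constraint::Length(1), // title
                    Constraint::Length(3), // batch tabs
                    Constraint::Min(0),    // gauges, statistics, histograms, ...
                ]
                .as_ref(),
            )
            .split(f.size());

        // Borderless, bold one-line header.
        let title = Block::default()
            .borders(Borders::NONE)
            .title("Model: some-model | Sequence Length: 10 | Decode Length: 10")
            .style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White));
        f.render_widget(title, rows[0]);

        // One tab per batch size, with the selected one highlighted.
        let titles: Vec<_> = [1u32, 2, 4]
            .iter()
            .map(|b| Spans::from(vec![Span::raw(format!("Batch: {b}"))]))
            .collect();
        let tabs = Tabs::new(titles)
            .block(Block::default().borders(Borders::ALL).title("Tabs"))
            .select(0)
            .highlight_style(Style::default().add_modifier(Modifier::BOLD));
        f.render_widget(tabs, rows[1]);
    })?;

    Ok(())
}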
@@ -186,28 +219,6 @@ impl UI {
             );
             f.render_widget(run_gauge, top[1]);

-            // Batch tabs
-            let titles = self
-                .batch_size
-                .iter()
-                .map(|b| {
-                    Spans::from(vec![
-                        Span::raw(format!("Batch: {b}")),
-                        // Span::styled(first, Style::default().fg(Color::Yellow)),
-                        // Span::styled(rest, Style::default().fg(Color::Green)),
-                    ])
-                })
-                .collect();
-            let tabs = Tabs::new(titles)
-                .block(Block::default().borders(Borders::ALL).title("Tabs"))
-                .select(current_tab_idx)
-                .style(Style::default().fg(Color::LightCyan))
-                .highlight_style(
-                    Style::default()
-                        .add_modifier(Modifier::BOLD)
-                        .bg(Color::Black),
-                );
-            f.render_widget(tabs, row4[1]);
             // Prefill text infos
             let (prefill_latency_statics, prefill_throughput_statics) = text_info(
                 &mut prefill_latencies[current_tab_idx],
@@ -384,7 +395,7 @@ fn latency_histogram<'a>(
         .block(
             Block::default()
                 .title(format!("{name} latency histogram"))
-                .style(Style::default().fg(Color::Yellow).bg(Color::Reset))
+                .style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
                 .borders(Borders::ALL),
         )
         .data(histo_data_str.as_slice())
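The histogram and text panels above are fed from per-tab latency vectors (prefill_latencies[current_tab_idx] and the like). As a rough illustration of how such raw samples become the average/percentile figures a panel displays (latency_summary and percentile are made-up helpers; the real text_info may compute different fields):

// Made-up helpers: turn raw per-run latency samples (in milliseconds) into
// the kind of summary figures a latency panel displays.
fn latency_summary(samples: &mut [f64]) -> (f64, f64, f64, f64) {
    samples.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let avg = samples.iter().sum::<f64>() / samples.len() as f64;
    (
        avg,
        percentile(samples, 0.50),
        percentile(samples, 0.90),
        percentile(samples, 0.99),
    )
}

// Nearest-rank percentile over an already-sorted, non-empty slice.
fn percentile(sorted: &[f64], q: f64) -> f64 {
    let idx = ((sorted.len() - 1) as f64 * q).round() as usize;
    sorted[idx]
}

fn main() {
    let mut samples = vec![12.0, 15.5, 11.2, 30.1, 14.8];
    let (avg, p50, p90, p99) = latency_summary(&mut samples);
    println!("avg {avg:.2} ms | p50 {p50:.2} | p90 {p90:.2} | p99 {p99:.2}");
}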