OlivierDehaene 2023-03-29 11:58:19 +02:00
parent 681744b982
commit 383619bd7f
10 changed files with 626 additions and 215 deletions

Cargo.lock (generated)

@@ -2324,6 +2324,7 @@ dependencies = [
  "average",
  "clap 4.1.8",
  "crossterm",
+ "float-ord",
  "ratatui",
  "serde",
  "serde_json",

benchmark/Cargo.toml

@@ -16,6 +16,7 @@ path = "src/main.rs"
 average = "0.13"
 clap = { version = "4.1.4", features = ["derive", "env"] }
 crossterm = "0.26"
+float-ord = "0.3.2"
 serde = {version = "1.0.142", features = ["derive"]}
 serde_json = "1.0"
 text-generation-client = { path = "../router/client" }
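
Note: float-ord enters the picture because f64 only implements PartialOrd, so a Vec<f64> of latencies cannot be sorted with sort() directly. The UI code later in this commit calls float_ord::sort for exactly this. A minimal sketch of that call (the sample values are invented):

fn main() {
    let mut latencies: Vec<f64> = vec![12.5, 3.1, 27.8, 9.4];
    // float_ord::sort wraps each element in FloatOrd, which implements Ord,
    // giving f64 slices a total order.
    float_ord::sort(&mut latencies);
    assert_eq!(latencies, vec![3.1, 9.4, 12.5, 27.8]);
}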

benchmark/src/lib.rs

@@ -1,27 +1,46 @@
-mod ui;
+extern crate core;
+
+mod ui;
+mod utils;

-use std::time::{Duration, Instant};
-use tokenizers::{Tokenizer, TruncationDirection};
-use tokio::time;
-use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
-use tokio::sync::{mpsc, broadcast};
 use crate::ui::UI;
+use std::time::{Duration, Instant};
+use text_generation_client::{
+    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+    StoppingCriteriaParameters,
+};
+use tokenizers::{Tokenizer, TruncationDirection};
+use tokio::sync::{broadcast, mpsc};

 const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

 #[derive(Debug, Clone)]
-pub(crate) enum Step {
-    Prefill,
-    Decode,
+pub(crate) struct Prefill {
+    batch_size: u32,
+    sequence_length: u32,
+    latency: Duration,
 }

 #[derive(Debug, Clone)]
-pub(crate) struct Run {
-    step: Step,
+pub(crate) struct Decode {
     batch_size: u32,
     sequence_length: u32,
     decode_length: u32,
-    time: Duration,
+    latency: Duration,
+}
+
+#[derive(Debug)]
+pub(crate) struct Run {
+    prefill: Prefill,
+    decode: Decode,
+}
+
+#[derive(Debug)]
+pub(crate) enum Message {
+    Prefill(Prefill),
+    Decode(Decode),
+    IncreaseRun,
+    IncreaseBatch,
 }

@@ -30,68 +49,89 @@ pub async fn run(
     sequence_length: u32,
     decode_length: u32,
     n_runs: usize,
+    warmups: usize,
     mut client: ShardedClient,
 ) -> Result<(), Box<dyn std::error::Error>> {
-    let (run_sender, run_receiver) = mpsc::channel(8);
+    let (ui_sender, ui_receiver) = mpsc::channel(8);
     let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);

     tokio::spawn(
         UI {
             n_run: n_runs,
-            n_batch: batch_size.len(),
-            n_batch_done: 0,
-            run_receiver,
+            batch_size: batch_size.clone(),
+            receiver: ui_receiver,
             shutdown_sender,
-        }.draw()
+        }
+        .draw(),
     );

     let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
     let sequence = create_sequence(sequence_length, tokenizer);

+    for _ in 0..warmups {
+        let (_, decode_batch) = tokio::select! {
+            res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
+            _ = shutdown_receiver.recv() => {
+                return Ok(());
+            }
+        };
+        let _ = tokio::select! {
+            res = run_decode(decode_batch, sequence_length, &mut client) => res?,
+            _ = shutdown_receiver.recv() => {
+                return Ok(());
+            }
+        };
+    }
+
     for b in batch_size {
         for _ in 0..n_runs {
-            let (run, decode_batch) = tokio::select! {
+            let (prefill, decode_batch) = tokio::select! {
                 res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
                 _ = shutdown_receiver.recv() => {
-                    tracing::info!("shutdown");
                     return Ok(());
                 }
             };
-            run_sender.send(run.clone()).await.unwrap();
-            runs.push(run);
+            ui_sender
+                .send(Message::Prefill(prefill.clone()))
+                .await
+                .unwrap();

-            let run = tokio::select! {
+            let decode = tokio::select! {
                 res = run_decode(decode_batch, sequence_length, &mut client) => res?,
                 _ = shutdown_receiver.recv() => {
-                    tracing::info!("shutdown");
                     return Ok(());
                 }
             };
-            run_sender.send(run.clone()).await.unwrap();
-            runs.push(run);
+            ui_sender
+                .send(Message::Decode(decode.clone()))
+                .await
+                .unwrap();
+            runs.push(Run { prefill, decode });
+            ui_sender.send(Message::IncreaseRun).await.unwrap();
         }
+        ui_sender.send(Message::IncreaseBatch).await.unwrap();
     }

-    // Shutdown UI by dropping run sender triggering the UI to exit its rendering loop
-    drop(run_sender);
+    // Signal the UI that we are done
+    drop(ui_sender);

-    // Wait for UI to shutdown
+    // Wait for UI shutdown signal
     let _ = shutdown_receiver.recv().await;
     Ok(())
 }

 async fn run_prefill(
     sequence: String,
     sequence_length: u32,
     batch_size: u32,
     decode_length: u32,
-    client: &mut ShardedClient) -> Result<(Run, Batch), ClientError> {
-    let requests = (0..batch_size).map(|id| {
-        Request {
+    client: &mut ShardedClient,
+) -> Result<(Prefill, Batch), ClientError> {
+    let requests = (0..batch_size)
+        .map(|id| Request {
             id: id.into(),
             inputs: sequence.clone(),
             parameters: Some(NextTokenChooserParameters {
@@ -109,8 +149,8 @@ async fn run_prefill(
                 stop_sequences: vec![],
                 ignore_eos_token: true,
             }),
-        }
-    }).collect();
+        })
+        .collect();

     let batch = Batch {
         id: 0,
@@ -124,18 +164,20 @@ async fn run_prefill(
     let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

-    let run = Run {
-        step: Step::Prefill,
+    let step = Prefill {
         batch_size,
         sequence_length,
-        decode_length: 1,
-        time: elasped,
+        latency: elasped,
     };

-    Ok((run, decode_batch))
+    Ok((step, decode_batch))
 }

-async fn run_decode(batch: Batch, sequence_length: u32, client: &mut ShardedClient) -> Result<Run, ClientError> {
+async fn run_decode(
+    batch: Batch,
+    sequence_length: u32,
+    client: &mut ShardedClient,
+) -> Result<Decode, ClientError> {
     let batch_size = batch.size;
     let mut decode_length = 0;
     let start_time = Instant::now();
@@ -147,24 +189,26 @@ async fn run_decode(
         decode_length += 1;
     }
     let elapsed = start_time.elapsed();

-    let run = Run {
-        step: Step::Decode,
+    let step = Decode {
         batch_size,
         sequence_length,
         decode_length,
-        time: elapsed,
+        latency: elapsed,
     };
-    Ok(run)
+    Ok(step)
 }

 fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
     let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
     // Repeat lorem ipsum to cover sequence length
-    let string_sequence = LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
+    let string_sequence =
+        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
     // Encode sequence
     let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
     // Truncate to sequence_length
     encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
     // Decode
-    tokenizer.decode(Vec::from(encoding.get_ids()), false).unwrap()
+    tokenizer
+        .decode(Vec::from(encoding.get_ids()), false)
+        .unwrap()
 }
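
Note: the run loop above talks to the UI over two channels, an mpsc channel carrying Message values one way and a broadcast channel used as a shutdown signal the other way; dropping ui_sender is what tells the UI that the benchmark is finished. A standalone sketch of that handshake (names invented, not code from this commit):

use tokio::sync::{broadcast, mpsc};

#[tokio::main]
async fn main() {
    let (msg_tx, mut msg_rx) = mpsc::channel::<u32>(8);
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel::<()>(1);

    // Consumer task, standing in for UI::draw: drains messages until the
    // channel closes, then signals the producer that it may exit.
    let ui = tokio::spawn(async move {
        while let Some(m) = msg_rx.recv().await {
            println!("message: {m}");
        }
        let _ = shutdown_tx.send(());
    });

    for i in 0..3 {
        msg_tx.send(i).await.unwrap();
    }
    drop(msg_tx); // closing the channel is the "we are done" signal
    let _ = shutdown_rx.recv().await; // mirrors shutdown_receiver.recv().await above
    ui.await.unwrap();
}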

benchmark/src/main.rs

@@ -1,11 +1,11 @@
+use clap::Parser;
 /// Text Generation Inference benchmarking tool
 use std::path::Path;
-use clap::Parser;
+use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
-use tracing_subscriber::EnvFilter;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
-use text_generation_client::ShardedClient;
+use tracing_subscriber::EnvFilter;

 /// App Configuration
 #[derive(Parser, Debug)]
@@ -15,13 +15,15 @@ struct Args {
     tokenizer_name: String,
     #[clap(default_value = "1", long, env)]
     batch_size: Vec<u32>,
-    #[clap(default_value = "128", long, env)]
+    #[clap(default_value = "12", long, env)]
     sequence_length: u32,
-    #[clap(default_value = "100", long, env)]
+    #[clap(default_value = "10", long, env)]
     decode_length: u32,
-    #[clap(default_value = "2", long, env)]
+    #[clap(default_value = "10", long, env)]
     runs: usize,
-    #[clap(default_value = "/tmp/text-generation-0", long, env)]
+    #[clap(default_value = "0", long, env)]
+    warmups: usize,
+    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
 }
@@ -35,22 +37,29 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         sequence_length,
         decode_length,
         runs,
+        warmups,
         master_shard_uds_path,
     } = args;

+    init_logging();
+
     // Tokenizer instance
     // This will only be used to validate payloads
+    tracing::info!("Loading tokenizer");
     let local_path = Path::new(&tokenizer_name);
     let tokenizer =
         if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
         {
             // Load local tokenizer
-            Tokenizer::from_file(local_path.join("tokenizer.json")).expect("unable to load local tokenizer")
+            tracing::info!("Found local tokenizer");
+            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
         } else {
             // Download and instantiate tokenizer
             // We need to download it outside of the Tokio runtime
-            Tokenizer::from_pretrained(tokenizer_name.clone(), None).expect("unable to load hub tokenizer")
+            tracing::info!("Downloading tokenizer");
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
         };
+    tracing::info!("Tokenizer loaded");

     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
@@ -58,8 +67,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .build()
         .unwrap()
         .block_on(async {
-            init_logging();
-
             // Instantiate sharded client from the master unix socket
             tracing::info!("Connect to model server");
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
@@ -78,8 +85,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 sequence_length,
                 decode_length,
                 runs,
+                warmups,
                 sharded_client,
-            ).await.unwrap();
+            )
+            .await
+            .unwrap();
         });
     Ok(())
 }
@@ -91,7 +101,6 @@ fn init_logging() {
         .with_file(true)
         .with_line_number(true);

     // Filter events with LOG_LEVEL
-
     let env_filter =
         EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
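
Note: moving init_logging() out of the block_on closure matters because the new tracing::info! calls ("Loading tokenizer", and so on) now run before the Tokio runtime is built; with the old ordering those events fired before the subscriber existed and were dropped. Based on the fmt-layer and env-filter lines visible in the hunk above, the function presumably looks roughly like this (a sketch, not the full file):

use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;

fn init_logging() {
    // fmt layer with file/line info, as configured in the hunk above
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    // Filter events with LOG_LEVEL, defaulting to "info"
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .init();
}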

benchmark/src/ui.rs

@@ -1,42 +1,59 @@
 /// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs
+use crate::Message;
 use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers};
-use crossterm::ExecutableCommand;
+use crossterm::{event, ExecutableCommand};
+use std::collections::BTreeMap;
 use std::io;
 use std::time::{Duration, Instant};
 use tokio::sync::mpsc::error::TryRecvError;
+use tokio::sync::{broadcast, mpsc};
 use tokio::time::sleep;
 use tui::backend::CrosstermBackend;
 use tui::layout::{Constraint, Direction, Layout};
-use tui::style::{Color, Style};
+use tui::style::{Color, Modifier, Style};
 use tui::text::{Span, Spans};
-use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
-use tui::Terminal;
-use tokio::sync::{mpsc, broadcast};
-use crate::{Run, Step};
+use tui::widgets::{
+    Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
+};
+use tui::{symbols, Terminal};

 pub(crate) struct UI {
     pub(crate) n_run: usize,
-    pub(crate) n_batch: usize,
-    pub(crate) n_batch_done: usize,
-    pub(crate) run_receiver: mpsc::Receiver<Run>,
+    pub(crate) batch_size: Vec<u32>,
+    pub(crate) receiver: mpsc::Receiver<Message>,
     pub(crate) shutdown_sender: broadcast::Sender<()>,
 }

 impl UI {
-    pub async fn draw(
-        mut self
-    ) -> Result<(), crossterm::ErrorKind> {
+    pub async fn draw(mut self) -> Result<(), crossterm::ErrorKind> {
         crossterm::terminal::enable_raw_mode()?;
         io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
         io::stdout().execute(crossterm::cursor::Hide)?;

-        let mut prefill_latency = Vec::new();
-        let mut prefill_throughput = Vec::new();
-        let mut decode_latency = Vec::new();
-        let mut decode_throughput = Vec::new();
-        let mut runs = Vec::new();
+        let mut current_tab_idx = 0;
+
+        let mut prefill_latencies: Vec<Vec<f64>> = (0..self.batch_size.len())
+            .map(|_| Vec::with_capacity(self.n_run))
+            .collect();
+        let mut prefill_throughputs: Vec<Vec<f64>> = (0..self.batch_size.len())
+            .map(|_| Vec::with_capacity(self.n_run))
+            .collect();
+        let mut decode_latencies: Vec<Vec<f64>> = (0..self.batch_size.len())
+            .map(|_| Vec::with_capacity(self.n_run))
+            .collect();
+        let mut decode_throughputs: Vec<Vec<f64>> = (0..self.batch_size.len())
+            .map(|_| Vec::with_capacity(self.n_run))
+            .collect();
+
+        let mut prefill_batch_latency_throughput: Vec<(f64, f64)> =
+            Vec::with_capacity(self.batch_size.len());
+        let mut decode_batch_latency_throughput: Vec<(f64, f64)> =
+            Vec::with_capacity(self.batch_size.len());
+
+        let mut completed_runs: Vec<usize> = (0..self.batch_size.len()).map(|_| 0).collect();
+        let mut completed_batch = 0;
+        let mut current_batch_idx = 0;

         let mut terminal = {
             let backend = CrosstermBackend::new(io::stdout());
@@ -46,153 +63,248 @@ impl UI {
         'outer: loop {
             let frame_start = Instant::now();

             loop {
-                match self.run_receiver.try_recv() {
-                    Ok(run) => {
-                        match run.step {
-                            Step::Prefill => {
-                                let latency = run.time.as_millis() as f64;
-                                let throughput = run.batch_size as f64 / run.time.as_secs_f64();
-                                prefill_latency.push(latency);
-                                prefill_throughput.push(throughput);
-                            }
-                            Step::Decode => {
-                                let latency = run.time.as_millis() as f64;
-                                let throughput = (run.batch_size * run.decode_length) as f64 / run.time.as_secs_f64();
-                                decode_latency.push(latency);
-                                decode_throughput.push(throughput);
-                            }
-                        }
-                        runs.push(run);
-                    }
+                match self.receiver.try_recv() {
+                    Ok(message) => match message {
+                        Message::Prefill(step) => {
+                            let latency = step.latency.as_millis() as f64;
+                            let throughput = step.batch_size as f64 / step.latency.as_secs_f64();
+                            prefill_latencies[current_batch_idx].push(latency);
+                            prefill_throughputs[current_batch_idx].push(throughput);
+                        }
+                        Message::Decode(step) => {
+                            let latency = step.latency.as_millis() as f64;
+                            let throughput = (step.batch_size * step.decode_length) as f64
+                                / step.latency.as_secs_f64();
+                            decode_latencies[current_batch_idx].push(latency);
+                            decode_throughputs[current_batch_idx].push(throughput);
+                        }
+                        Message::IncreaseRun => {
+                            completed_runs[current_batch_idx] += 1;
+                        }
+                        Message::IncreaseBatch => {
+                            prefill_batch_latency_throughput.push((
+                                prefill_latencies[current_batch_idx].iter().sum::<f64>()
+                                    / completed_runs[current_batch_idx] as f64,
+                                prefill_throughputs[current_batch_idx].iter().sum::<f64>()
+                                    / completed_runs[current_batch_idx] as f64,
+                            ));
+                            decode_batch_latency_throughput.push((
+                                decode_latencies[current_batch_idx].iter().sum::<f64>()
+                                    / completed_runs[current_batch_idx] as f64,
+                                decode_throughputs[current_batch_idx].iter().sum::<f64>()
+                                    / completed_runs[current_batch_idx] as f64,
+                            ));
+                            completed_batch += 1;
+                            if current_batch_idx < self.batch_size.len() - 1 {
+                                current_batch_idx += 1;
+                            }
+                        }
+                    },
                     Err(TryRecvError::Empty) => {
                         break;
                     }
                     Err(TryRecvError::Disconnected) => {
-                        // Application ends.
-                        break 'outer;
+                        break;
                     }
                 }
             }

-            let batch_progress = (self.n_batch_done as f64 / self.n_batch as f64).clamp(0.0, 1.0);
-            let run_progress = (runs.len() as f64 / self.n_run as f64).clamp(0.0, 1.0);
+            let batch_progress =
+                (completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0);
+            let run_progress =
+                (completed_runs[current_batch_idx] as f64 / self.n_run as f64).clamp(0.0, 1.0);

             terminal.draw(|f| {
-                let row3 = Layout::default()
+                // Vertical layout
+                let row4 = Layout::default()
                     .direction(Direction::Vertical)
                     .constraints(
                         [
                             Constraint::Length(3),
-                            Constraint::Length(10),
-                            Constraint::Percentage(45),
-                        ].as_ref(),
-                    ).split(f.size());
+                            Constraint::Length(3),
+                            Constraint::Length(13),
+                            Constraint::Min(10),
+                        ]
+                        .as_ref(),
+                    )
+                    .split(f.size());

+                // Top row horizontal layout
                 let top = Layout::default()
                     .direction(Direction::Horizontal)
-                    .constraints([
-                        Constraint::Percentage(50),
-                        Constraint::Percentage(50),
-                    ].as_ref()).split(row3[0]);
+                    .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
+                    .split(row4[0]);

+                // Mid row horizontal layout
                 let mid = Layout::default()
                     .direction(Direction::Horizontal)
-                    .constraints([
-                        Constraint::Percentage(20),
-                        Constraint::Percentage(30),
-                        Constraint::Percentage(20),
-                        Constraint::Percentage(30),
-                    ].as_ref()).split(row3[1]);
+                    .constraints(
+                        [
+                            Constraint::Percentage(20),
+                            Constraint::Percentage(30),
+                            Constraint::Percentage(20),
+                            Constraint::Percentage(30),
+                        ]
+                        .as_ref(),
+                    )
+                    .split(row4[2]);

+                // Left mid row vertical layout
                 let prefill_text = Layout::default()
                     .direction(Direction::Vertical)
-                    .constraints([
-                        Constraint::Length(5),
-                        Constraint::Length(5),
-                    ].as_ref()).split(mid[0]);
+                    .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
+                    .split(mid[0]);

+                // Right mid row vertical layout
                 let decode_text = Layout::default()
                     .direction(Direction::Vertical)
-                    .constraints([
-                        Constraint::Length(5),
-                        Constraint::Length(5),
-                    ].as_ref()).split(mid[2]);
+                    .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
+                    .split(mid[2]);

+                // Bottom row horizontal layout
                 let bottom = Layout::default()
                     .direction(Direction::Horizontal)
-                    .constraints([
-                        Constraint::Percentage(25),
-                        Constraint::Percentage(25),
-                        Constraint::Percentage(25),
-                        Constraint::Percentage(25),
-                    ].as_ref()).split(row3[2]);
+                    .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
+                    .split(row4[3]);

-                let batch_gauge = Gauge::default()
-                    .block(Block::default().title("Total Progress").borders(Borders::ALL))
-                    .gauge_style(Style::default().fg(Color::White))
-                    .label(Span::raw(format!("{} / {}", self.n_batch_done, self.n_batch)))
-                    .ratio(batch_progress);
+                // Total progress bar
+                let batch_gauge = progress_gauge(
+                    "Total Progress",
+                    format!("{} / {}", completed_batch, self.batch_size.len()),
+                    batch_progress,
+                    Color::LightGreen,
+                );
                 f.render_widget(batch_gauge, top[0]);

-                let run_gauge = Gauge::default()
-                    .block(Block::default().title("Batch Progress").borders(Borders::ALL))
-                    .gauge_style(Style::default().fg(Color::White))
-                    .label(Span::raw(format!("{} / {}", runs.len(), self.n_run)))
-                    .ratio(run_progress);
+                // Batch progress Bar
+                let run_gauge = progress_gauge(
+                    "Batch Progress",
+                    format!("{} / {}", completed_runs[current_batch_idx], self.n_run),
+                    run_progress,
+                    Color::LightBlue,
+                );
                 f.render_widget(run_gauge, top[1]);

-                let prefill_latency_texts = statis_spans(&prefill_latency, "ms", false);
-                let prefill_throughput_texts = statis_spans(&prefill_throughput, "tokens/secs", false);
+                // Batch tabs
+                let titles = self
+                    .batch_size
+                    .iter()
+                    .map(|b| {
+                        Spans::from(vec![
+                            Span::raw(format!("Batch: {b}")), // Span::styled(first, Style::default().fg(Color::Yellow)),
+                            // Span::styled(rest, Style::default().fg(Color::Green)),
+                        ])
+                    })
+                    .collect();
+                let tabs = Tabs::new(titles)
+                    .block(Block::default().borders(Borders::ALL).title("Tabs"))
+                    .select(current_tab_idx)
+                    .style(Style::default().fg(Color::LightCyan))
+                    .highlight_style(
+                        Style::default()
+                            .add_modifier(Modifier::BOLD)
+                            .bg(Color::Black),
+                    );
+                f.render_widget(tabs, row4[1]);

-                let prefill_latency_statics = Paragraph::new(prefill_latency_texts).block(
-                    Block::default()
-                        .title(Span::raw("Prefill Latency"))
-                        .borders(Borders::ALL),
+                // Prefill text infos
+                let (prefill_latency_statics, prefill_throughput_statics) = text_info(
+                    &mut prefill_latencies[current_tab_idx],
+                    &prefill_throughputs[current_tab_idx],
+                    "Prefill",
                 );
                 f.render_widget(prefill_latency_statics, prefill_text[0]);
-
-                let prefill_throughput_statics = Paragraph::new(prefill_throughput_texts).block(
-                    Block::default()
-                        .title(Span::raw("Prefill Throughput"))
-                        .borders(Borders::ALL),
-                );
                 f.render_widget(prefill_throughput_statics, prefill_text[1]);

-                let decode_latency_texts = statis_spans(&decode_latency, "ms", false);
-                let decode_throughput_texts = statis_spans(&decode_throughput, "tokens/secs", false);
+                // Prefill latency histogram
+                let histo_width = 7;
+                let bins = if mid[1].width < 2 {
+                    0
+                } else {
+                    (mid[1].width as usize - 2) / (histo_width + 1)
+                }
+                .max(2);
+
+                let histo_data = latency_histogram_data(&prefill_latencies[current_tab_idx], bins);
+                let histo_data_str: Vec<(&str, u64)> =
+                    histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
+                let prefill_histogram =
+                    latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16);
+                f.render_widget(prefill_histogram, mid[1]);

-                let decode_latency_statics = Paragraph::new(decode_latency_texts).block(
-                    Block::default()
-                        .title(Span::raw("Decode Latency"))
-                        .borders(Borders::ALL),
+                // Decode text info
+                let (decode_latency_statics, decode_throughput_statics) = text_info(
+                    &mut decode_latencies[current_tab_idx],
+                    &decode_throughputs[current_tab_idx],
+                    "Decode",
                 );
                 f.render_widget(decode_latency_statics, decode_text[0]);
-
-                let decode_throughput_statics = Paragraph::new(decode_throughput_texts).block(
-                    Block::default()
-                        .title(Span::raw("Decode Throughput"))
-                        .borders(Borders::ALL),
-                );
                 f.render_widget(decode_throughput_statics, decode_text[1]);
+
+                // Decode latency histogram
+                let histo_data = latency_histogram_data(&decode_latencies[current_tab_idx], bins);
+                let histo_data_str: Vec<(&str, u64)> =
+                    histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
+                let decode_histogram =
+                    latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16);
+                f.render_widget(decode_histogram, mid[3]);
+
+                // Prefill latency/throughput chart
+                let prefill_latency_throughput_chart = latency_throughput_chart(
+                    &prefill_batch_latency_throughput,
+                    &self.batch_size,
+                    "Prefill",
+                );
+                f.render_widget(prefill_latency_throughput_chart, bottom[0]);
+
+                // Decode latency/throughput chart
+                let decode_latency_throughput_chart = latency_throughput_chart(
+                    &decode_batch_latency_throughput,
+                    &self.batch_size,
+                    "Decode",
+                );
+                f.render_widget(decode_latency_throughput_chart, bottom[1]);
             })?;

-            while crossterm::event::poll(Duration::from_secs(0))? {
-                match crossterm::event::read()? {
-                    Event::Key(KeyEvent {
-                        code: KeyCode::Char('q'),
-                        ..
-                    })
-                    | Event::Key(KeyEvent {
-                        code: KeyCode::Char('c'),
-                        modifiers: KeyModifiers::CONTROL,
-                        ..
-                    }) => {
-                        break 'outer;
-                    }
-                    _ => (),
-                }
-            }
+            // Quit on q or CTRL+c
+            while event::poll(Duration::from_millis(100))? {
+                if let Event::Key(key) = event::read()? {
+                    match key {
+                        KeyEvent {
+                            code: KeyCode::Right,
+                            ..
+                        } => {
+                            current_tab_idx = (current_tab_idx + 1) % self.batch_size.len();
+                        }
+                        KeyEvent {
+                            code: KeyCode::Left,
+                            ..
+                        } => {
+                            if current_tab_idx > 0 {
+                                current_tab_idx -= 1;
+                            } else {
+                                current_tab_idx = self.batch_size.len() - 1;
+                            }
+                        }
+                        KeyEvent {
+                            code: KeyCode::Char('q'),
+                            ..
+                        }
+                        | KeyEvent {
+                            code: KeyCode::Char('c'),
+                            modifiers: KeyModifiers::CONTROL,
+                            ..
+                        } => {
+                            break 'outer;
+                        }
+                        _ => (),
+                    }
+                }
+            }

+            // Frame budget
             let per_frame = Duration::from_secs(1) / 30 as u32;
             let elapsed = frame_start.elapsed();
             if per_frame > elapsed {
@@ -200,6 +312,7 @@ impl UI {
             }
         }

+        // Revert terminal to original view
         io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
         crossterm::terminal::disable_raw_mode()?;
         io::stdout().execute(crossterm::cursor::Show)?;
@@ -209,13 +322,87 @@ impl UI {
     }
 }

-fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str, color: bool) -> Vec<Spans<'a>> {
+fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge {
+    Gauge::default()
+        .block(Block::default().title(title).borders(Borders::ALL))
+        .gauge_style(Style::default().fg(color))
+        .label(Span::raw(label))
+        .ratio(progress)
+}
+
+fn text_info<'a>(
+    latency: &mut Vec<f64>,
+    throughput: &Vec<f64>,
+    name: &'static str,
+) -> (Paragraph<'a>, Paragraph<'a>) {
+    let mut latency_texts = statis_spans(&latency, "ms");
+    float_ord::sort(latency);
+    let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);
+    let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed];
+    for (i, (name, value)) in latency_percentiles.iter().enumerate() {
+        let span = Spans::from(vec![Span::styled(
+            format!("{name}: {:.4} ms", value),
+            Style::default().fg(colors[i]),
+        )]);
+        latency_texts.push(span);
+    }
+
+    let throughput_texts = statis_spans(&throughput, "tokens/secs");
+
+    let latency_statics = Paragraph::new(latency_texts).block(
+        Block::default()
+            .title(Span::raw(format!("{name} Latency")))
+            .borders(Borders::ALL),
+    );
+    let throughput_statics = Paragraph::new(throughput_texts).block(
+        Block::default()
+            .title(Span::raw(format!("{name} Throughput")))
+            .borders(Borders::ALL),
+    );
+
+    (latency_statics, throughput_statics)
+}
+
+fn latency_histogram_data(latency: &Vec<f64>, bins: usize) -> Vec<(String, u64)> {
+    let histo_data: Vec<(String, u64)> = {
+        let histo = crate::utils::histogram(latency, bins);
+        histo
+            .into_iter()
+            .map(|(label, v)| (format!("{label:.2}"), v as u64))
+            .collect()
+    };
+    histo_data
+}
+
+fn latency_histogram<'a>(
+    histo_data_str: &'a Vec<(&'a str, u64)>,
+    name: &'static str,
+) -> BarChart<'a> {
+    BarChart::default()
+        .block(
+            Block::default()
+                .title(format!("{name} latency histogram"))
+                .style(Style::default().fg(Color::Yellow).bg(Color::Reset))
+                .borders(Borders::ALL),
+        )
+        .data(histo_data_str.as_slice())
+}
+
+fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
     vec![
+        Spans::from(vec![Span::styled(
+            format!(
+                "Average: {:.4} {unit}",
+                data.iter().sum::<f64>() / data.len() as f64
+            ),
+            Style::default().fg(Color::LightBlue),
+        )]),
         Spans::from(vec![Span::styled(
             format!(
                 "Lowest: {:.4} {unit}",
-                data
-                    .iter()
+                data.iter()
                     .min_by(|a, b| a.total_cmp(b))
                     .unwrap_or(&std::f64::NAN)
             ),
@@ -224,22 +411,141 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
         Spans::from(vec![Span::styled(
             format!(
                 "Highest: {:.4} {unit}",
-                data
-                    .iter()
+                data.iter()
                     .max_by(|a, b| a.total_cmp(b))
                     .unwrap_or(&std::f64::NAN)
             ),
             Style::default().fg(Color::Reset),
         )]),
-        Spans::from(vec![Span::styled(
-            format!(
-                "Average: {:.4} {unit}",
-                data
-                    .iter()
-                    .sum::<f64>()
-                    / data.len() as f64
-            ),
-            Style::default().fg(Color::Reset),
-        )]),
     ]
 }
+
+fn latency_throughput_chart<'a>(
+    latency_throughput: &'a Vec<(f64, f64)>,
+    batch_sizes: &'a Vec<u32>,
+    name: &'static str,
+) -> Chart<'a> {
+    let latency_iter = latency_throughput.iter().map(|(l, _)| l);
+    let throughput_iter = latency_throughput.iter().map(|(_, t)| t);
+
+    let min_latency: f64 = *latency_iter
+        .clone()
+        .min_by(|a, b| a.total_cmp(b))
+        .unwrap_or(&std::f64::NAN);
+    let max_latency: f64 = *latency_iter
+        .max_by(|a, b| a.total_cmp(b))
+        .unwrap_or(&std::f64::NAN);
+    let min_throughput: f64 = *throughput_iter
+        .clone()
+        .min_by(|a, b| a.total_cmp(b))
+        .unwrap_or(&std::f64::NAN);
+    let max_throughput: f64 = *throughput_iter
+        .max_by(|a, b| a.total_cmp(b))
+        .unwrap_or(&std::f64::NAN);
+
+    let min_x = ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0;
+    let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0;
+    let step_x = (max_x - min_x) / 4.0;
+
+    let min_y = ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0;
+    let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0;
+    let step_y = (max_y - min_y) / 4.0;
+
+    let mut x_labels = vec![Span::styled(
+        format!("{:.2}", min_x),
+        Style::default()
+            .add_modifier(Modifier::BOLD)
+            .fg(Color::Gray)
+            .bg(Color::Reset),
+    )];
+    for i in 0..3 {
+        x_labels.push(Span::styled(
+            format!("{:.2}", min_x + ((i + 1) as f64 * step_x)),
+            Style::default().fg(Color::Gray).bg(Color::Reset),
+        ));
+    }
+    x_labels.push(Span::styled(
+        format!("{:.2}", max_x),
+        Style::default()
+            .add_modifier(Modifier::BOLD)
+            .fg(Color::Gray)
+            .bg(Color::Reset),
+    ));
+
+    let mut y_labels = vec![Span::styled(
+        format!("{:.2}", min_y),
+        Style::default()
+            .add_modifier(Modifier::BOLD)
+            .fg(Color::Gray)
+            .bg(Color::Reset),
+    )];
+    for i in 0..3 {
+        y_labels.push(Span::styled(
+            format!("{:.2}", min_y + ((i + 1) as f64 * step_y)),
+            Style::default().fg(Color::Gray).bg(Color::Reset),
+        ));
+    }
+    y_labels.push(Span::styled(
+        format!("{:.2}", max_y),
+        Style::default()
+            .add_modifier(Modifier::BOLD)
+            .fg(Color::Gray)
+            .bg(Color::Reset),
+    ));
+
+    let colors = color_vec();
+    let datasets: Vec<Dataset> = (0..latency_throughput.len())
+        .map(|i| {
+            Dataset::default()
+                .name(batch_sizes[i].to_string())
+                .marker(symbols::Marker::Block)
+                .style(Style::default().fg(colors[i]))
+                .graph_type(GraphType::Scatter)
+                .data(&latency_throughput[i..(i + 1)])
+        })
+        .collect();
+
+    Chart::new(datasets)
+        .style(Style::default().fg(Color::Cyan).bg(Color::Reset))
+        .block(
+            Block::default()
+                .title(Span::styled(
+                    format!("{name} throughput over latency"),
+                    Style::default().fg(Color::Gray).bg(Color::Reset),
+                ))
+                .borders(Borders::ALL),
+        )
+        .x_axis(
+            Axis::default()
+                .title(format!("ms"))
+                .style(Style::default().fg(Color::Gray).bg(Color::Reset))
+                .labels(x_labels)
+                .bounds([min_x, max_x]),
+        )
+        .y_axis(
+            Axis::default()
+                .title(format!("tokens/secs"))
+                .style(Style::default().fg(Color::Gray).bg(Color::Reset))
+                .labels(y_labels)
+                .bounds([min_y, max_y]),
+        )
+}
+
+fn color_vec() -> Vec<Color> {
+    vec![
+        Color::Red,
+        Color::Green,
+        Color::Yellow,
+        Color::Blue,
+        Color::Magenta,
+        Color::Cyan,
+        Color::Gray,
+        Color::DarkGray,
+        Color::LightRed,
+        Color::LightGreen,
+        Color::LightYellow,
+        Color::LightBlue,
+        Color::LightMagenta,
+        Color::LightCyan,
+    ]
+}
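
Note: the IncreaseBatch arm collapses each batch size's per-run samples into a single mean (latency, throughput) point, which is what latency_throughput_chart scatters, one colored point per batch size. The reduction in isolation (sample numbers invented):

// One (mean latency, mean throughput) point per completed batch size,
// matching the reduction in the IncreaseBatch arm above.
fn mean_point(latencies: &[f64], throughputs: &[f64], completed_runs: usize) -> (f64, f64) {
    (
        latencies.iter().sum::<f64>() / completed_runs as f64,
        throughputs.iter().sum::<f64>() / completed_runs as f64,
    )
}

fn main() {
    let latencies = [100.0, 110.0, 90.0]; // ms, invented
    let throughputs = [40.0, 36.0, 44.0]; // tokens/s, invented
    assert_eq!(mean_point(&latencies, &throughputs, 3), (100.0, 40.0));
}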

benchmark/src/utils.rs (new file)

@@ -0,0 +1,43 @@
+/// MIT License
+//
+// Copyright (c) 2020 hatoo
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+use std::collections::BTreeMap;
+
+pub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> {
+    assert!(bins >= 2);
+    let mut bucket: Vec<usize> = vec![0; bins];
+    let min = values.iter().collect::<average::Min>().min();
+    let max = values.iter().collect::<average::Max>().max();
+    let step = (max - min) / (bins - 1) as f64;
+
+    for &v in values {
+        let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1);
+        bucket[i] += 1;
+    }
+
+    bucket
+        .into_iter()
+        .enumerate()
+        .map(|(i, v)| (min + step * i as f64, v))
+        .collect()
+}
+
+pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f64> {
+    pecents
+        .iter()
+        .map(|&p| {
+            let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
+            (format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
+        })
+        .collect()
+}
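
Note: both helpers are easy to sanity-check. percentiles indexes directly into the slice, so it assumes the input is already sorted, which is why text_info in ui.rs calls float_ord::sort first. A unit-test sketch that could sit at the bottom of this file (sample numbers invented):

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn histogram_and_percentiles() {
        // Pre-sorted, made-up latencies in milliseconds.
        let latencies = vec![10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 18.0, 25.0, 40.0];

        // 3 bins over [10, 40] with step 15; each label is the bucket's lower edge.
        let histo = histogram(&latencies, 3);
        assert_eq!(histo, vec![(10.0, 1), (25.0, 8), (40.0, 1)]);

        // p50 -> index 5; p90 and p99 both -> index 9 after the usize cast.
        let pcts = percentiles(&latencies, &[50, 90, 99]);
        assert_eq!(pcts["p50"], 15.0);
        assert_eq!(pcts["p90"], 40.0);
        assert_eq!(pcts["p99"], 40.0);
    }
}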

router/src/main.rs

@@ -37,7 +37,7 @@ struct Args {
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "/tmp/text-generation-0", long, env)]
+    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
@@ -76,6 +76,8 @@ fn main() -> Result<(), std::io::Error> {
         panic!("validation_workers must be > 0");
     }

+    init_logging(otlp_endpoint, json_output);
+
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
     // Finally, convert to AllowOrigin
@@ -89,17 +91,21 @@ fn main() -> Result<(), std::io::Error> {
     // Tokenizer instance
     // This will only be used to validate payloads
+    tracing::info!("Loading tokenizer");
     let local_path = Path::new(&tokenizer_name);
     let tokenizer =
         if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
         {
             // Load local tokenizer
+            tracing::info!("Found local tokenizer");
             Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
         } else {
             // Download and instantiate tokenizer
             // We need to download it outside of the Tokio runtime
+            tracing::info!("Downloading tokenizer");
             Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
         };
+    tracing::info!("Tokenizer loaded");

     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
@@ -107,8 +113,6 @@ fn main() -> Result<(), std::io::Error> {
         .build()
         .unwrap()
         .block_on(async {
-            init_logging(otlp_endpoint, json_output);
-
             // Get pipeline tag
             let model_info = reqwest::get(format!(
                 "https://huggingface.co/api/models/{tokenizer_name}"

router/src/validation.rs

@@ -315,7 +315,7 @@ fn validate(
     let stopping_parameters = StoppingCriteriaParameters {
         max_new_tokens,
         stop_sequences,
-        ignore_eos_token: false
+        ignore_eos_token: false,
     };

     metrics::histogram!("tgi_request_input_length", input_length as f64);

server/text_generation_server/cli.py

@@ -18,7 +18,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
-    uds_path: Path = "/tmp/text-generation",
+    uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,

server/text_generation_server/utils/tokens.py

@@ -158,5 +158,8 @@ class StoppingCriteria:
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
         return StoppingCriteria(
-            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens, pb.ignore_eos_token
+            tokenizer.eos_token_id,
+            stop_sequence_criterias,
+            pb.max_new_tokens,
+            pb.ignore_eos_token,
         )