Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-06-19 15:52:08 +00:00)

Commit b6df2036ed ("v1"), parent 17a75c8845
Makefile (+3 lines)

@@ -7,6 +7,9 @@ install-router:
 install-launcher:
 	cd launcher && cargo install --path .
 
+install-benchmark:
+	cd benchmark && cargo install --path .
+
 install: install-server install-router install-launcher
 
 server-dev:
benchmark/Cargo.toml

@@ -9,7 +9,7 @@ description = "Text Generation Benchmarking tool"
 path = "src/lib.rs"
 
 [[bin]]
-name = "text-generation-bench"
+name = "text-generation-benchmark"
 path = "src/main.rs"
 
 [dependencies]
benchmark/README.md (new file, +24 lines)

@@ -0,0 +1,24 @@
+# Text Generation Inference benchmarking tool
+
+A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha)
+and powered by [tui](https://github.com/tui-rs-revival/ratatui).
+
+## Install
+
+```shell
+make install-benchmark
+```
+
+## Run
+
+First, start `text-generation-inference`:
+
+```shell
+text-generation-launcher --model-id bigscience/bloom-560m
+```
+
+Then run the benchmarking tool:
+
+```shell
+text-generation-benchmark --tokenizer-name bigscience/bloom-560m
+```
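Both commands above use the defaults. Judging from the `Args` struct added in `benchmark/src/main.rs` further down, most knobs can be overridden on the command line; an illustrative invocation (flag values are examples only): `text-generation-benchmark --tokenizer-name bigscience/bloom-560m --batch-size 1 --batch-size 8 --sequence-length 10 --decode-length 64 --runs 10 --warmups 1`.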
benchmark/src/app.rs (renamed from benchmark/src/ui.rs)

@@ -1,5 +1,5 @@
+/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
 use crate::generation::{Decode, Message, Prefill};
-/// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs
 use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc;
@@ -12,70 +12,8 @@ use tui::widgets::{
 };
 use tui::{symbols, Frame};
 
-struct Data {
-    prefill_latencies: Vec<Vec<f64>>,
-    prefill_throughputs: Vec<Vec<f64>>,
-    decode_latencies: Vec<Vec<f64>>,
-    decode_throughputs: Vec<Vec<f64>>,
-    prefill_batch_latency_throughput: Vec<(f64, f64)>,
-    decode_batch_latency_throughput: Vec<(f64, f64)>,
-}
-
-impl Data {
-    fn new(n_run: usize, n_batch: usize) -> Self {
-        let prefill_latencies: Vec<Vec<f64>> =
-            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
-        let prefill_throughputs: Vec<Vec<f64>> =
-            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
-
-        let decode_latencies: Vec<Vec<f64>> =
-            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
-        let decode_throughputs: Vec<Vec<f64>> =
-            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
-
-        let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
-
-        let decode_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
-
-        Self {
-            prefill_latencies,
-            prefill_throughputs,
-            decode_latencies,
-            decode_throughputs,
-            prefill_batch_latency_throughput,
-            decode_batch_latency_throughput,
-        }
-    }
-
-    fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) {
-        let latency = prefill.latency.as_millis() as f64;
-        self.prefill_latencies[batch_idx].push(latency);
-        self.prefill_throughputs[batch_idx].push(prefill.throughput);
-    }
-
-    fn push_decode(&mut self, prefill: Decode, batch_idx: usize) {
-        let latency = prefill.latency.as_millis() as f64;
-        self.decode_latencies[batch_idx].push(latency);
-        self.decode_throughputs[batch_idx].push(prefill.throughput);
-    }
-
-    fn end_batch(&mut self, batch_idx: usize) {
-        self.prefill_batch_latency_throughput.push((
-            self.prefill_latencies[batch_idx].iter().sum::<f64>()
-                / self.prefill_latencies[batch_idx].len() as f64,
-            self.prefill_throughputs[batch_idx].iter().sum::<f64>()
-                / self.prefill_throughputs[batch_idx].len() as f64,
-        ));
-        self.decode_batch_latency_throughput.push((
-            self.decode_latencies[batch_idx].iter().sum::<f64>()
-                / self.decode_latencies[batch_idx].len() as f64,
-            self.decode_throughputs[batch_idx].iter().sum::<f64>()
-                / self.decode_throughputs[batch_idx].len() as f64,
-        ));
-    }
-}
-
-pub(crate) struct UI {
+/// TUI powered App
+pub(crate) struct App {
     pub(crate) running: bool,
     completed_runs: Vec<usize>,
     completed_batch: usize,

@@ -92,7 +30,7 @@ pub(crate) struct UI {
     receiver: mpsc::Receiver<Result<Message, ClientError>>,
 }
 
-impl UI {
+impl App {
     pub(crate) fn new(
         receiver: mpsc::Receiver<Result<Message, ClientError>>,
         tokenizer_name: String,
@@ -127,18 +65,20 @@ impl UI {
         }
     }
 
+    /// Handle crossterm key events
     pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) {
         match key_event {
+            // Increase and wrap tab
             KeyEvent {
                 code: KeyCode::Right,
                 ..
-            } |
-            KeyEvent {
-                code: KeyCode::Tab,
-                ..
+            }
+            | KeyEvent {
+                code: KeyCode::Tab, ..
             } => {
                 self.current_tab = (self.current_tab + 1) % self.batch_size.len();
             }
+            // Decrease and wrap tab
             KeyEvent {
                 code: KeyCode::Left,
                 ..
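The reworked match arm above relies on Rust or-patterns: two `KeyEvent` struct patterns share a single arm, and `..` discards the remaining fields. A minimal self-contained sketch of the same shape, using a hypothetical `KeyEvent` stand-in rather than the crossterm type:

```rust
// Hypothetical stand-in for crossterm's KeyEvent, just to show the pattern.
struct KeyEvent {
    code: char,
    #[allow(dead_code)]
    modifiers: u8,
}

// Two struct patterns share one arm via `|`, and `..` ignores other fields,
// exactly like the Right/Tab arm in the diff above.
fn next_tab(event: &KeyEvent, current: usize, n_tabs: usize) -> usize {
    match event {
        KeyEvent { code: 'l', .. } | KeyEvent { code: '\t', .. } => (current + 1) % n_tabs,
        _ => current,
    }
}

fn main() {
    let tab = KeyEvent { code: '\t', modifiers: 0 };
    assert_eq!(next_tab(&tab, 2, 3), 0); // wraps back to the first tab
    println!("ok");
}
```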
@@ -149,19 +89,21 @@ impl UI {
                     self.current_tab = self.batch_size.len() - 1;
                 }
             }
+            // Zoom on throughput/latency fig
             KeyEvent {
                 code: KeyCode::Char('+'),
                 ..
             } => {
                 self.zoom = true;
             }
+            // Unzoom on throughput/latency fig
             KeyEvent {
                 code: KeyCode::Char('-'),
                 ..
             } => {
                 self.zoom = false;
             }
+            // Quit
             KeyEvent {
                 code: KeyCode::Char('q'),
                 ..
@@ -177,13 +119,14 @@ impl UI {
         }
     }
 
+    /// Get all pending messages from generation task
    pub(crate) fn tick(&mut self) {
         while let Ok(message) = self.receiver.try_recv() {
             match message {
                 Ok(message) => match message {
                     Message::Prefill(step) => self.data.push_prefill(step, self.current_batch),
                     Message::Decode(step) => self.data.push_decode(step, self.current_batch),
-                    Message::Run(_) => {
+                    Message::EndRun => {
                         self.completed_runs[self.current_batch] += 1;
                     }
                     Message::EndBatch => {

@@ -201,6 +144,7 @@ impl UI {
         }
     }
 
+    /// Render frame
     pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
         let batch_progress =
             (self.completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0);
@@ -277,14 +221,9 @@ impl UI {
         // Helper
         let helper = Block::default()
             .borders(Borders::NONE)
-            .title(format!(
-                "<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom"
-            ))
+            .title("<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom")
             .title_alignment(Alignment::Right)
-            .style(
-                Style::default()
-                    .fg(Color::White),
-            );
+            .style(Style::default().fg(Color::White));
         f.render_widget(helper, row5[0]);
 
         // Batch tabs
@@ -404,6 +343,71 @@ impl UI {
     }
 }
 
+/// App internal data struct
+struct Data {
+    prefill_latencies: Vec<Vec<f64>>,
+    prefill_throughputs: Vec<Vec<f64>>,
+    decode_latencies: Vec<Vec<f64>>,
+    decode_throughputs: Vec<Vec<f64>>,
+    prefill_batch_latency_throughput: Vec<(f64, f64)>,
+    decode_batch_latency_throughput: Vec<(f64, f64)>,
+}
+
+impl Data {
+    fn new(n_run: usize, n_batch: usize) -> Self {
+        let prefill_latencies: Vec<Vec<f64>> =
+            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
+        let prefill_throughputs: Vec<Vec<f64>> =
+            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
+
+        let decode_latencies: Vec<Vec<f64>> =
+            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
+        let decode_throughputs: Vec<Vec<f64>> =
+            (0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
+
+        let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
+
+        let decode_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
+
+        Self {
+            prefill_latencies,
+            prefill_throughputs,
+            decode_latencies,
+            decode_throughputs,
+            prefill_batch_latency_throughput,
+            decode_batch_latency_throughput,
+        }
+    }
+
+    fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) {
+        let latency = prefill.latency.as_millis() as f64;
+        self.prefill_latencies[batch_idx].push(latency);
+        self.prefill_throughputs[batch_idx].push(prefill.throughput);
+    }
+
+    fn push_decode(&mut self, prefill: Decode, batch_idx: usize) {
+        let latency = prefill.latency.as_millis() as f64;
+        self.decode_latencies[batch_idx].push(latency);
+        self.decode_throughputs[batch_idx].push(prefill.throughput);
+    }
+
+    fn end_batch(&mut self, batch_idx: usize) {
+        self.prefill_batch_latency_throughput.push((
+            self.prefill_latencies[batch_idx].iter().sum::<f64>()
+                / self.prefill_latencies[batch_idx].len() as f64,
+            self.prefill_throughputs[batch_idx].iter().sum::<f64>()
+                / self.prefill_throughputs[batch_idx].len() as f64,
+        ));
+        self.decode_batch_latency_throughput.push((
+            self.decode_latencies[batch_idx].iter().sum::<f64>()
+                / self.decode_latencies[batch_idx].len() as f64,
+            self.decode_throughputs[batch_idx].iter().sum::<f64>()
+                / self.decode_throughputs[batch_idx].len() as f64,
+        ));
+    }
+}
+
+/// Progress bar
 fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge {
     Gauge::default()
         .block(Block::default().title(title).borders(Borders::ALL))
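`end_batch` above collapses the per-run measurements for one batch size into a single (mean latency, mean throughput) point, which the scatter chart later plots. A standalone sketch of that computation (made-up numbers, not part of the commit):

```rust
// How `end_batch` reduces per-run data: one (mean latency, mean throughput)
// point per batch size.
fn mean(values: &[f64]) -> f64 {
    values.iter().sum::<f64>() / values.len() as f64
}

fn main() {
    // Per-run prefill latencies (ms) and throughputs (tokens/s) for one batch size
    let latencies = vec![12.1, 11.8, 12.5];
    let throughputs = vec![660.0, 676.0, 640.0];

    let point = (mean(&latencies), mean(&throughputs));
    println!("batch point: {point:?}"); // roughly (12.13, 658.67)
}
```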
@@ -412,31 +416,40 @@ fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Ga
         .ratio(progress)
 }
 
+/// Prefill or Decode text infos
 fn text_info<'a>(
     latency: &mut Vec<f64>,
     throughput: &Vec<f64>,
     name: &'static str,
 ) -> (Paragraph<'a>, Paragraph<'a>) {
-    let mut latency_texts = statis_spans(&latency, "ms");
+    // Latency average/high/low texts
+    let mut latency_texts = statis_spans(latency, "ms");
+
+    // Sort latency for percentiles
     float_ord::sort(latency);
     let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);
 
+    // Latency p50/p90/p99 texts
     let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed];
     for (i, (name, value)) in latency_percentiles.iter().enumerate() {
         let span = Spans::from(vec![Span::styled(
-            format!("{name}: {:.4} ms", value),
+            format!("{name}: {value:.4} ms"),
             Style::default().fg(colors[i]),
         )]);
         latency_texts.push(span);
     }
 
-    let throughput_texts = statis_spans(&throughput, "tokens/secs");
+    // Throughput average/high/low texts
+    let throughput_texts = statis_spans(throughput, "tokens/secs");
+
+    // Latency Block
     let latency_statics = Paragraph::new(latency_texts).block(
         Block::default()
             .title(Span::raw(format!("{name} Latency")))
             .borders(Borders::ALL),
     );
 
+    // Throughput block
     let throughput_statics = Paragraph::new(throughput_texts).block(
         Block::default()
             .title(Span::raw(format!("{name} Throughput")))
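The `format!` changes in this hunk switch from positional arguments to inline format-args capture, stable since Rust 1.58. Both spellings produce identical output; a tiny demonstration (not part of the commit):

```rust
fn main() {
    let name = "p50";
    let value = 12.3456789_f64;
    assert_eq!(
        format!("{name}: {:.4} ms", value), // positional argument
        format!("{name}: {value:.4} ms"),   // inline capture (Rust 1.58+)
    );
    println!("{name}: {value:.4} ms"); // p50: 12.3457 ms
}
```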
@@ -446,32 +459,7 @@ fn text_info<'a>(
     (latency_statics, throughput_statics)
 }
 
-fn latency_histogram_data(latency: &Vec<f64>, bins: usize) -> Vec<(String, u64)> {
-    let histo_data: Vec<(String, u64)> = {
-        let histo = crate::utils::histogram(latency, bins);
-        histo
-            .into_iter()
-            .map(|(label, v)| (format!("{label:.2}"), v as u64))
-            .collect()
-    };
-
-    histo_data
-}
-
-fn latency_histogram<'a>(
-    histo_data_str: &'a Vec<(&'a str, u64)>,
-    name: &'static str,
-) -> BarChart<'a> {
-    BarChart::default()
-        .block(
-            Block::default()
-                .title(format!("{name} latency histogram"))
-                .style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
-                .borders(Borders::ALL),
-        )
-        .data(histo_data_str.as_slice())
-}
-
+/// Average/High/Low spans
 fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
     vec![
         Spans::from(vec![Span::styled(
@@ -502,15 +490,45 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
     ]
 }
 
+/// Latency histogram data
+fn latency_histogram_data(latency: &[f64], bins: usize) -> Vec<(String, u64)> {
+    let histo_data: Vec<(String, u64)> = {
+        let histo = crate::utils::histogram(latency, bins);
+        histo
+            .into_iter()
+            .map(|(label, v)| (format!("{label:.2}"), v as u64))
+            .collect()
+    };
+
+    histo_data
+}
+
+/// Latency Histogram
+fn latency_histogram<'a>(
+    histo_data_str: &'a Vec<(&'a str, u64)>,
+    name: &'static str,
+) -> BarChart<'a> {
+    BarChart::default()
+        .block(
+            Block::default()
+                .title(format!("{name} latency histogram"))
+                .style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
+                .borders(Borders::ALL),
+        )
+        .data(histo_data_str.as_slice())
+}
+
+/// Latency/Throughput chart
 fn latency_throughput_chart<'a>(
     latency_throughput: &'a Vec<(f64, f64)>,
-    batch_sizes: &'a Vec<u32>,
+    batch_sizes: &'a [u32],
     zoom: bool,
     name: &'static str,
 ) -> Chart<'a> {
     let latency_iter = latency_throughput.iter().map(|(l, _)| l);
     let throughput_iter = latency_throughput.iter().map(|(_, t)| t);
 
+    // Get extreme values
     let min_latency: f64 = *latency_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
@@ -526,6 +544,7 @@ fn latency_throughput_chart<'a>(
         .max_by(|a, b| a.total_cmp(b))
         .unwrap_or(&std::f64::NAN);
 
+    // Char min max values
     let min_x = if zoom {
         ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0
     } else {

@@ -534,6 +553,7 @@ fn latency_throughput_chart<'a>(
     let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0;
     let step_x = (max_x - min_x) / 4.0;
 
+    // Chart min max values
     let min_y = if zoom {
         ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0
     } else {

@@ -542,8 +562,9 @@ fn latency_throughput_chart<'a>(
     let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0;
     let step_y = (max_y - min_y) / 4.0;
 
+    // Labels
     let mut x_labels = vec![Span::styled(
-        format!("{:.2}", min_x),
+        format!("{min_x:.2}"),
         Style::default()
             .add_modifier(Modifier::BOLD)
             .fg(Color::Gray)

@@ -556,15 +577,16 @@ fn latency_throughput_chart<'a>(
         ));
     }
     x_labels.push(Span::styled(
-        format!("{:.2}", max_x),
+        format!("{max_x:.2}"),
         Style::default()
             .add_modifier(Modifier::BOLD)
             .fg(Color::Gray)
             .bg(Color::Reset),
     ));
 
+    // Labels
     let mut y_labels = vec![Span::styled(
-        format!("{:.2}", min_y),
+        format!("{min_y:.2}"),
         Style::default()
             .add_modifier(Modifier::BOLD)
             .fg(Color::Gray)

@@ -577,25 +599,29 @@ fn latency_throughput_chart<'a>(
         ));
     }
     y_labels.push(Span::styled(
-        format!("{:.2}", max_y),
+        format!("{max_y:.2}"),
         Style::default()
             .add_modifier(Modifier::BOLD)
             .fg(Color::Gray)
             .bg(Color::Reset),
     ));
 
+    // Chart dataset
     let colors = color_vec();
     let datasets: Vec<Dataset> = (0..latency_throughput.len())
         .map(|i| {
+            let color_idx = i % colors.len();
+
             Dataset::default()
                 .name(batch_sizes[i].to_string())
                 .marker(symbols::Marker::Block)
-                .style(Style::default().fg(colors[i]))
+                .style(Style::default().fg(colors[color_idx]))
                 .graph_type(GraphType::Scatter)
                 .data(&latency_throughput[i..(i + 1)])
         })
         .collect();
 
+    // Chart
     Chart::new(datasets)
         .style(Style::default().fg(Color::Cyan).bg(Color::Reset))
         .block(

@@ -608,20 +634,21 @@ fn latency_throughput_chart<'a>(
         )
         .x_axis(
             Axis::default()
-                .title(format!("ms"))
+                .title("ms")
                 .style(Style::default().fg(Color::Gray).bg(Color::Reset))
                 .labels(x_labels)
                 .bounds([min_x, max_x]),
         )
        .y_axis(
             Axis::default()
-                .title(format!("tokens/secs"))
+                .title("tokens/secs")
                 .style(Style::default().fg(Color::Gray).bg(Color::Reset))
                 .labels(y_labels)
                 .bounds([min_y, max_y]),
         )
 }
 
+// Colors for latency/throughput chart
 fn color_vec() -> Vec<Color> {
     vec![
         Color::Red,
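The new `color_idx = i % colors.len()` keeps dataset styling in bounds when more batch sizes are plotted than the palette has colors. A tiny sketch of the cycling behaviour (illustrative names only, not part of the commit):

```rust
fn main() {
    let colors = ["red", "green", "blue"];
    let n_datasets = 8; // more datasets than palette entries

    for i in 0..n_datasets {
        let color_idx = i % colors.len(); // never out of bounds, wraps around
        println!("dataset {i} -> {}", colors[color_idx]);
    }
}
```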
benchmark/src/event.rs

@@ -1,4 +1,4 @@
-/// Inspired by https://github.com/orhun/rust-tui-template
+/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
 use crossterm::event;
 use std::time::{Duration, Instant};
 use tokio::sync::{broadcast, mpsc};

@@ -20,6 +20,8 @@ pub(crate) async fn terminal_event_task(
     mut shutdown_receiver: broadcast::Receiver<()>,
     _shutdown_guard_sender: mpsc::Sender<()>,
 ) {
+    // End task if a message is received on shutdown_receiver
+    // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
         _ = event_loop(fps, event_sender) => {
         },

@@ -27,14 +29,21 @@
     }
 }
 
+/// Main event loop
 async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
-    let per_frame = Duration::from_secs(1) / fps as u32;
+    // Frame budget
+    let per_frame = Duration::from_secs(1) / fps;
+
+    // When was last frame executed
     let mut last_frame = Instant::now();
+
     loop {
+        // Sleep to avoid blocking the thread for too long
         if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) {
             tokio::time::sleep(sleep).await;
         }
+
+        // Get crossterm event and send a new one over the channel
         if event::poll(Duration::from_secs(0)).expect("no events available") {
             match event::read().expect("unable to read event") {
                 event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),

@@ -45,8 +54,11 @@ async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
             }
         }
 
+        // Frame budget exceeded
         if last_frame.elapsed() >= per_frame {
+            // Send tick
             event_sender.send(Event::Tick).await.unwrap_or(());
+            // Reset last_frame time
             last_frame = Instant::now();
         }
     }
 }
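`event_loop` above derives a per-frame budget from the fps target and sleeps away whatever is left of it before polling. A blocking, std-only sketch of the same budgeting idea (the commit itself uses `tokio::time::sleep`):

```rust
use std::time::{Duration, Instant};

fn main() {
    let fps = 30u32;
    // Duration implements Div<u32>, so one second split by the fps target
    let per_frame = Duration::from_secs(1) / fps;

    let mut last_frame = Instant::now();
    for _ in 0..3 {
        // Sleep only for the unused part of the frame budget
        if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) {
            std::thread::sleep(sleep);
        }
        // Budget exhausted: emit a tick and reset the frame clock
        if last_frame.elapsed() >= per_frame {
            println!("tick after {:?}", last_frame.elapsed());
            last_frame = Instant::now();
        }
    }
}
```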
benchmark/src/generation.rs

@@ -16,28 +16,21 @@ pub(crate) struct Prefill {
 
 #[derive(Debug, Clone)]
 pub(crate) struct Decode {
-    pub(crate) decode_length: u32,
     pub(crate) latency: Duration,
     pub(crate) throughput: f64,
 }
 
-#[derive(Debug)]
-pub(crate) struct Run {
-    pub(crate) batch_size: u32,
-    pub(crate) sequence_length: u32,
-    pub(crate) prefill: Prefill,
-    pub(crate) decode: Decode,
-}
-
 #[derive(Debug)]
 pub(crate) enum Message {
     Warmup,
     Prefill(Prefill),
     Decode(Decode),
-    Run(Run),
+    EndRun,
     EndBatch,
 }
 
+/// Benchmarking task
+#[allow(clippy::too_many_arguments)]
 pub(crate) async fn generation_task(
     tokenizer: Tokenizer,
     batch_size: Vec<u32>,

@@ -50,6 +43,8 @@ pub(crate) async fn generation_task(
     mut shutdown_receiver: broadcast::Receiver<()>,
     _shutdown_guard_sender: mpsc::Sender<()>,
 ) {
+    // End task if a message is received on shutdown_receiver
+    // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
         res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => {
             if let Err(err) = res {

@@ -60,6 +55,8 @@ pub(crate) async fn generation_task(
     }
 }
 
+/// Benchmark prefill/decode
+#[allow(clippy::too_many_arguments)]
 async fn generate_runs(
     tokenizer: Tokenizer,
     batch_size: Vec<u32>,
@@ -70,52 +67,53 @@ async fn generate_runs(
     mut client: ShardedClient,
     run_sender: mpsc::Sender<Result<Message, ClientError>>,
 ) -> Result<(), ClientError> {
+    // Create a dummy sequence
     let sequence = create_sequence(sequence_length, tokenizer);
 
     for b in batch_size {
+        // Warmups on batch size
         for _ in 0..warmups {
             let (_, decode_batch) =
                 prefill(sequence.clone(), b, decode_length, &mut client).await?;
             let _ = decode(decode_batch, &mut client).await?;
+            // Send warmup message
             run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
         }
 
         for _ in 0..n_runs {
             let (prefill, decode_batch) =
                 prefill(sequence.clone(), b, decode_length, &mut client).await?;
+            // Send prefill message
             run_sender
-                .send(Ok(Message::Prefill(prefill.clone())))
+                .send(Ok(Message::Prefill(prefill)))
                 .await
                 .unwrap_or(());
 
             let decode = decode(decode_batch, &mut client).await?;
 
+            // Send decode message
             run_sender
-                .send(Ok(Message::Decode(decode.clone())))
+                .send(Ok(Message::Decode(decode)))
                 .await
                 .unwrap_or(());
 
-            run_sender
-                .send(Ok(Message::Run(Run {
-                    batch_size: b,
-                    sequence_length,
-                    prefill,
-                    decode,
-                })))
-                .await
-                .unwrap_or(());
+            // Send run ended message
+            run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());
         }
+        // Batch ended
         run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
     }
     Ok(())
 }
 
+// Run a prefill step
 async fn prefill(
     sequence: String,
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, Batch), ClientError> {
+    // Create requests
     let requests = (0..batch_size)
         .map(|id| Request {
             id: id.into(),
@@ -133,7 +131,7 @@ async fn prefill(
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: decode_length,
                 stop_sequences: vec![],
-                ignore_eos_token: true,
+                ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
         })
         .collect();

@@ -144,11 +142,17 @@ async fn prefill(
         size: batch_size,
     };
 
+    // Run prefill
     let start_time = Instant::now();
     let (_, decode_batch) = client.prefill(batch.clone()).await?;
+
+    // Get latency
     let latency = start_time.elapsed();
+
+    // Compute throughput from latency and batch size
     let throughput = batch_size as f64 / latency.as_secs_f64();
+
+    // Decode batch cannot be empty
     let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
 
     let step = Prefill {
@@ -159,28 +163,35 @@ async fn prefill(
     Ok((step, decode_batch))
 }
 
+/// Run a full decode
 async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
     let mut decode_length = 0;
-    let start_time = Instant::now();
     let batch_size = batch.size;
 
+    let start_time = Instant::now();
+
+    // Full decode over decode length
     let mut next_batch = Some(batch);
     while let Some(batch) = next_batch {
         let result = client.decode(vec![batch]).await?;
         next_batch = result.1;
         decode_length += 1;
     }
+
+    // Get latency
     let latency = start_time.elapsed();
+
+    // Compute throughput from latency, batch size and decode length
     let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();
+
     let step = Decode {
-        decode_length,
         latency,
         throughput,
     };
     Ok(step)
 }
 
+/// Create a dummy sequence of the correct length
 fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
     let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
     // Repeat lorem ipsum to cover sequence length
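The two throughput formulas used above: prefill handles `batch_size` requests in one forward pass, while a full decode produces `batch_size * decode_length` tokens. A worked example with made-up latencies (not part of the commit):

```rust
use std::time::Duration;

fn main() {
    let batch_size = 8u32;
    let decode_length = 64u32;

    let prefill_latency = Duration::from_millis(50);
    let decode_latency = Duration::from_millis(1600);

    // One prefill step serves the whole batch
    let prefill_throughput = batch_size as f64 / prefill_latency.as_secs_f64();
    // Decode emits one token per request per step, decode_length steps total
    let decode_throughput =
        (batch_size * decode_length) as f64 / decode_latency.as_secs_f64();

    println!("prefill: {prefill_throughput} requests/s"); // 160 requests/s
    println!("decode: {decode_throughput} tokens/s");     // 320 tokens/s
}
```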
benchmark/src/lib.rs

@@ -1,12 +1,10 @@
-extern crate core;
-
+mod app;
 mod event;
 mod generation;
-mod ui;
 mod utils;
 
+use crate::app::App;
 use crate::event::Event;
-use crate::ui::UI;
 use crossterm::ExecutableCommand;
 use std::io;
 use text_generation_client::ShardedClient;

@@ -15,6 +13,8 @@ use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;
 use tui::Terminal;
 
+/// Run benchmarking app
+#[allow(clippy::too_many_arguments)]
 pub async fn run(
     tokenizer_name: String,
     tokenizer: Tokenizer,
@@ -25,11 +25,27 @@ pub async fn run(
     warmups: usize,
     client: ShardedClient,
 ) -> Result<(), crossterm::ErrorKind> {
+    // Initialize terminal properties
+    crossterm::terminal::enable_raw_mode()?;
+    io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
+    io::stdout().execute(crossterm::cursor::Hide)?;
+
+    // Initialize terminal
+    let mut terminal = {
+        let backend = CrosstermBackend::new(io::stdout());
+        Terminal::new(backend)?
+    };
+
+    // Create message channel between generation_task and app
     let (run_sender, run_receiver) = mpsc::channel(8);
+    // Crossterm event channel
     let (event_sender, mut event_receiver) = mpsc::channel(8);
+    // Shutdown channel to terminate tasks
     let (shutdown_sender, _) = broadcast::channel(1);
+    // Channel to check if tasks terminated
     let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
 
+    // Create generation task
     tokio::spawn(generation::generation_task(
         tokenizer,
         batch_size.clone(),

@@ -43,6 +59,7 @@ pub async fn run(
         shutdown_guard_sender.clone(),
     ));
 
+    // Create event task
     tokio::spawn(event::terminal_event_task(
         250,
         event_sender,

@@ -50,9 +67,11 @@ pub async fn run(
         shutdown_guard_sender.clone(),
     ));
 
+    // Drop our end of shutdown sender
     drop(shutdown_guard_sender);
 
-    let mut ui = UI::new(
+    // Create App
+    let mut app = App::new(
         run_receiver,
         tokenizer_name,
         sequence_length,

@@ -61,23 +80,17 @@ pub async fn run(
         batch_size,
     );
 
-    crossterm::terminal::enable_raw_mode()?;
-    io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
-    io::stdout().execute(crossterm::cursor::Hide)?;
-
-    let mut terminal = {
-        let backend = CrosstermBackend::new(io::stdout());
-        Terminal::new(backend)?
-    };
-
-    while ui.running {
-        terminal.draw(|frame| ui.render(frame))?;
+    while app.running {
+        // Draw frame
+        terminal.draw(|frame| app.render(frame))?;
 
+        // Await a new event from event handling task
         match event_receiver.recv().await {
             None => break,
+            // Update app state
             Some(event) => match event {
-                Event::Tick => ui.tick(),
-                Event::Key(key_event) => ui.handle_key_event(key_event),
+                Event::Tick => app.tick(),
+                Event::Key(key_event) => app.handle_key_event(key_event),
                 _ => {}
             },
         }
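`run` above holds the tasks open with a guard `mpsc` channel: every task gets a clone of the sender and never sends on it, so once all clones are dropped, `recv()` resolves to `None`. A minimal sketch of the pattern (assumes tokio with the `rt`, `macros` and `sync` features; not part of the commit):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (guard_sender, mut guard_receiver) = mpsc::channel::<()>(1);

    for i in 0..2 {
        let guard = guard_sender.clone();
        tokio::spawn(async move {
            // Held for the task's lifetime, dropped when the task finishes
            let _guard = guard;
            println!("task {i} done");
        });
    }

    // Drop our own end so only the tasks keep the channel open
    drop(guard_sender);

    // Resolves once every task has dropped its guard
    let _ = guard_receiver.recv().await;
    println!("all tasks terminated");
}
```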
benchmark/src/main.rs

@@ -1,5 +1,8 @@
-use clap::Parser;
 /// Text Generation Inference benchmarking tool
+///
+/// Inspired by the great Oha app: https://github.com/hatoo/oha
+/// and: https://github.com/orhun/rust-tui-template
+use clap::Parser;
 use std::path::Path;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;

@@ -11,17 +14,17 @@ use tracing_subscriber::EnvFilter;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
-    #[clap(default_value = "bigscience/bloom", long, env)]
+    #[clap(long, env)]
     tokenizer_name: String,
-    #[clap(default_value = "1", long, env)]
-    batch_size: Vec<u32>,
+    #[clap(long)]
+    batch_size: Option<Vec<u32>>,
     #[clap(default_value = "10", long, env)]
     sequence_length: u32,
     #[clap(default_value = "64", long, env)]
     decode_length: u32,
     #[clap(default_value = "10", long, env)]
     runs: usize,
-    #[clap(default_value = "2", long, env)]
+    #[clap(default_value = "1", long, env)]
     warmups: usize,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
@@ -41,6 +44,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         master_shard_uds_path,
     } = args;
 
+    let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
+
     init_logging();
 
     // Tokenizer instance
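Moving `batch_size` to `Option<Vec<u32>>` lets the multi-value default `[1, 2, 4, 8, 16, 32]` live in code, where clap's single `default_value = "1"` string could not express it. The defaulting itself is plain `Option` handling; a tiny sketch (not part of the commit):

```rust
fn main() {
    // Pretend --batch-size was not passed on the command line
    let cli_value: Option<Vec<u32>> = None;
    let batch_size = cli_value.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
    println!("{batch_size:?}"); // [1, 2, 4, 8, 16, 32]
}
```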
@@ -79,6 +84,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .expect("Unable to clear cache");
     tracing::info!("Connected");
 
+    // Run app
     text_generation_benchmark::run(
         tokenizer_name,
         tokenizer,