mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-21 16:40:20 +00:00
improving design
This commit is contained in:
parent
ae72d4f96f
commit
271f045825
@ -1,7 +1,7 @@
|
|||||||
/// Inspired by https://github.com/orhun/rust-tui-template
|
/// Inspired by https://github.com/orhun/rust-tui-template
|
||||||
use crossterm::event;
|
use crossterm::event;
|
||||||
use tokio::sync::{mpsc, broadcast};
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
/// Events
|
/// Events
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -14,9 +14,11 @@ pub(crate) enum Event {
|
|||||||
Resize(u16, u16),
|
Resize(u16, u16),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn terminal_event_task(fps: u32, event_sender: mpsc::Sender<Event>,
|
pub(crate) async fn terminal_event_task(
|
||||||
mut shutdown_receiver: broadcast::Receiver<()>,
|
fps: u32,
|
||||||
_shutdown_guard_sender: mpsc::Sender<()>,
|
event_sender: mpsc::Sender<Event>,
|
||||||
|
mut shutdown_receiver: broadcast::Receiver<()>,
|
||||||
|
_shutdown_guard_sender: mpsc::Sender<()>,
|
||||||
) {
|
) {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = event_loop(fps, event_sender) => {
|
_ = event_loop(fps, event_sender) => {
|
||||||
@ -25,8 +27,7 @@ pub(crate) async fn terminal_event_task(fps: u32, event_sender: mpsc::Sender<Eve
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>,
|
async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
|
||||||
) {
|
|
||||||
let per_frame = Duration::from_secs(1) / fps as u32;
|
let per_frame = Duration::from_secs(1) / fps as u32;
|
||||||
let mut last_frame = Instant::now();
|
let mut last_frame = Instant::now();
|
||||||
loop {
|
loop {
|
||||||
@ -37,7 +38,9 @@ async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>,
|
|||||||
if event::poll(Duration::from_secs(0)).expect("no events available") {
|
if event::poll(Duration::from_secs(0)).expect("no events available") {
|
||||||
match event::read().expect("unable to read event") {
|
match event::read().expect("unable to read event") {
|
||||||
event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),
|
event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),
|
||||||
event::Event::Resize(w, h) => event_sender.send(Event::Resize(w, h)).await.unwrap_or(()),
|
event::Event::Resize(w, h) => {
|
||||||
|
event_sender.send(Event::Resize(w, h)).await.unwrap_or(())
|
||||||
|
}
|
||||||
_ => (),
|
_ => (),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use text_generation_client::{Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters};
|
use text_generation_client::{
|
||||||
|
Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
|
||||||
|
StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
use tokenizers::{Tokenizer, TruncationDirection};
|
use tokenizers::{Tokenizer, TruncationDirection};
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
@ -57,26 +60,29 @@ pub(crate) async fn generation_task(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_runs(tokenizer: Tokenizer,
|
async fn generate_runs(
|
||||||
batch_size: Vec<u32>,
|
tokenizer: Tokenizer,
|
||||||
sequence_length: u32,
|
batch_size: Vec<u32>,
|
||||||
decode_length: u32,
|
sequence_length: u32,
|
||||||
n_runs: usize,
|
decode_length: u32,
|
||||||
warmups: usize,
|
n_runs: usize,
|
||||||
mut client: ShardedClient,
|
warmups: usize,
|
||||||
run_sender: mpsc::Sender<Result<Message, ClientError>>,
|
mut client: ShardedClient,
|
||||||
|
run_sender: mpsc::Sender<Result<Message, ClientError>>,
|
||||||
) -> Result<(), ClientError> {
|
) -> Result<(), ClientError> {
|
||||||
let sequence = create_sequence(sequence_length, tokenizer);
|
let sequence = create_sequence(sequence_length, tokenizer);
|
||||||
|
|
||||||
for b in batch_size {
|
for b in batch_size {
|
||||||
for _ in 0..warmups {
|
for _ in 0..warmups {
|
||||||
let (_, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?;
|
let (_, decode_batch) =
|
||||||
|
prefill(sequence.clone(), b, decode_length, &mut client).await?;
|
||||||
let _ = decode(decode_batch, &mut client).await?;
|
let _ = decode(decode_batch, &mut client).await?;
|
||||||
run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
|
run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
|
||||||
}
|
}
|
||||||
|
|
||||||
for _ in 0..n_runs {
|
for _ in 0..n_runs {
|
||||||
let (prefill, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?;
|
let (prefill, decode_batch) =
|
||||||
|
prefill(sequence.clone(), b, decode_length, &mut client).await?;
|
||||||
run_sender
|
run_sender
|
||||||
.send(Ok(Message::Prefill(prefill.clone())))
|
.send(Ok(Message::Prefill(prefill.clone())))
|
||||||
.await
|
.await
|
||||||
@ -89,12 +95,15 @@ async fn generate_runs(tokenizer: Tokenizer,
|
|||||||
.await
|
.await
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
|
|
||||||
run_sender.send(Ok(Message::Run(Run {
|
run_sender
|
||||||
batch_size: b,
|
.send(Ok(Message::Run(Run {
|
||||||
sequence_length,
|
batch_size: b,
|
||||||
prefill,
|
sequence_length,
|
||||||
decode,
|
prefill,
|
||||||
}))).await.unwrap_or(());
|
decode,
|
||||||
|
})))
|
||||||
|
.await
|
||||||
|
.unwrap_or(());
|
||||||
}
|
}
|
||||||
run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
|
run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
|
||||||
}
|
}
|
||||||
@ -138,8 +147,7 @@ async fn prefill(
|
|||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let (_, decode_batch) = client.prefill(batch.clone()).await?;
|
let (_, decode_batch) = client.prefill(batch.clone()).await?;
|
||||||
let latency = start_time.elapsed();
|
let latency = start_time.elapsed();
|
||||||
let throughput = batch_size as f64
|
let throughput = batch_size as f64 / latency.as_secs_f64();
|
||||||
/ latency.as_secs_f64();
|
|
||||||
|
|
||||||
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
|
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
|
||||||
|
|
||||||
@ -151,10 +159,7 @@ async fn prefill(
|
|||||||
Ok((step, decode_batch))
|
Ok((step, decode_batch))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn decode(
|
async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
|
||||||
batch: Batch,
|
|
||||||
client: &mut ShardedClient,
|
|
||||||
) -> Result<Decode, ClientError> {
|
|
||||||
let mut decode_length = 0;
|
let mut decode_length = 0;
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let batch_size = batch.size;
|
let batch_size = batch.size;
|
||||||
@ -166,8 +171,7 @@ async fn decode(
|
|||||||
decode_length += 1;
|
decode_length += 1;
|
||||||
}
|
}
|
||||||
let latency = start_time.elapsed();
|
let latency = start_time.elapsed();
|
||||||
let throughput = (batch_size * decode_length) as f64
|
let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();
|
||||||
/ latency.as_secs_f64();
|
|
||||||
|
|
||||||
let step = Decode {
|
let step = Decode {
|
||||||
decode_length,
|
decode_length,
|
||||||
|
@ -1,15 +1,19 @@
|
|||||||
extern crate core;
|
extern crate core;
|
||||||
|
|
||||||
|
mod event;
|
||||||
|
mod generation;
|
||||||
mod ui;
|
mod ui;
|
||||||
mod utils;
|
mod utils;
|
||||||
mod generation;
|
|
||||||
mod event;
|
|
||||||
|
|
||||||
|
use crate::event::Event;
|
||||||
use crate::ui::UI;
|
use crate::ui::UI;
|
||||||
|
use crossterm::ExecutableCommand;
|
||||||
|
use std::io;
|
||||||
|
use text_generation_client::ShardedClient;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
use text_generation_client::ShardedClient;
|
use tui::backend::CrosstermBackend;
|
||||||
|
use tui::Terminal;
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
@ -20,33 +24,74 @@ pub async fn run(
|
|||||||
n_runs: usize,
|
n_runs: usize,
|
||||||
warmups: usize,
|
warmups: usize,
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), crossterm::ErrorKind> {
|
||||||
let (run_sender, run_receiver) = mpsc::channel(8);
|
let (run_sender, run_receiver) = mpsc::channel(8);
|
||||||
let (shutdown_sender, shutdown_receiver) = broadcast::channel(1);
|
let (event_sender, mut event_receiver) = mpsc::channel(8);
|
||||||
|
let (shutdown_sender, _) = broadcast::channel(1);
|
||||||
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
|
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
|
||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(generation::generation_task(
|
||||||
generation::generation_task(tokenizer, batch_size.clone(), sequence_length, decode_length, n_runs, warmups, client, run_sender, shutdown_receiver, shutdown_guard_sender.clone()),
|
tokenizer,
|
||||||
|
batch_size.clone(),
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
n_runs,
|
||||||
|
warmups,
|
||||||
|
client,
|
||||||
|
run_sender,
|
||||||
|
shutdown_sender.subscribe(),
|
||||||
|
shutdown_guard_sender.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
tokio::spawn(event::terminal_event_task(
|
||||||
|
250,
|
||||||
|
event_sender,
|
||||||
|
shutdown_sender.subscribe(),
|
||||||
|
shutdown_guard_sender.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
drop(shutdown_guard_sender);
|
||||||
|
|
||||||
|
let mut ui = UI::new(
|
||||||
|
run_receiver,
|
||||||
|
tokenizer_name,
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
n_runs,
|
||||||
|
batch_size,
|
||||||
);
|
);
|
||||||
|
|
||||||
tokio::spawn(
|
crossterm::terminal::enable_raw_mode()?;
|
||||||
UI {
|
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
||||||
tokenizer_name,
|
io::stdout().execute(crossterm::cursor::Hide)?;
|
||||||
decode_length,
|
|
||||||
sequence_length,
|
let mut terminal = {
|
||||||
n_run: n_runs,
|
let backend = CrosstermBackend::new(io::stdout());
|
||||||
batch_size: batch_size,
|
Terminal::new(backend)?
|
||||||
receiver: run_receiver,
|
};
|
||||||
shutdown_sender,
|
|
||||||
_shutdown_guard_sender: shutdown_guard_sender.clone()
|
while ui.running {
|
||||||
|
terminal.draw(|frame| ui.render(frame))?;
|
||||||
|
|
||||||
|
match event_receiver.recv().await {
|
||||||
|
None => break,
|
||||||
|
Some(event) => match event {
|
||||||
|
Event::Tick => ui.tick(),
|
||||||
|
Event::Key(key_event) => ui.handle_key_event(key_event),
|
||||||
|
_ => {}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
.draw(),
|
}
|
||||||
);
|
|
||||||
|
|
||||||
drop (shutdown_guard_sender);
|
|
||||||
|
|
||||||
|
// Ask tasks to shutdown
|
||||||
|
let _ = shutdown_sender.send(());
|
||||||
// Wait for tasks to shutdown
|
// Wait for tasks to shutdown
|
||||||
let _ = shutdown_guard_receiver.recv().await;
|
let _ = shutdown_guard_receiver.recv().await;
|
||||||
|
|
||||||
|
// Revert terminal to original view
|
||||||
|
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
||||||
|
crossterm::terminal::disable_raw_mode()?;
|
||||||
|
io::stdout().execute(crossterm::cursor::Show)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,341 +1,362 @@
|
|||||||
|
use crate::generation::{Decode, Message, Prefill};
|
||||||
/// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs
|
/// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs
|
||||||
use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers};
|
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
use crossterm::{event, ExecutableCommand};
|
use text_generation_client::ClientError;
|
||||||
use std::io;
|
use tokio::sync::mpsc;
|
||||||
use std::time::{Duration, Instant};
|
use tui::backend::Backend;
|
||||||
use tokio::sync::mpsc::error::TryRecvError;
|
|
||||||
use tokio::sync::{broadcast, mpsc};
|
|
||||||
use tokio::time::sleep;
|
|
||||||
use tui::backend::CrosstermBackend;
|
|
||||||
use tui::layout::{Constraint, Direction, Layout};
|
use tui::layout::{Constraint, Direction, Layout};
|
||||||
use tui::style::{Color, Modifier, Style};
|
use tui::style::{Color, Modifier, Style};
|
||||||
use tui::text::{Span, Spans};
|
use tui::text::{Span, Spans};
|
||||||
use tui::widgets::{
|
use tui::widgets::{
|
||||||
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
|
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
|
||||||
};
|
};
|
||||||
use tui::{symbols, Terminal};
|
use tui::{symbols, Frame};
|
||||||
use text_generation_client::ClientError;
|
|
||||||
use crate::generation::Message;
|
struct Data {
|
||||||
|
prefill_latencies: Vec<Vec<f64>>,
|
||||||
|
prefill_throughputs: Vec<Vec<f64>>,
|
||||||
|
decode_latencies: Vec<Vec<f64>>,
|
||||||
|
decode_throughputs: Vec<Vec<f64>>,
|
||||||
|
prefill_batch_latency_throughput: Vec<(f64, f64)>,
|
||||||
|
decode_batch_latency_throughput: Vec<(f64, f64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Data {
|
||||||
|
fn new(n_run: usize, n_batch: usize) -> Self {
|
||||||
|
let prefill_latencies: Vec<Vec<f64>> =
|
||||||
|
(0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
|
||||||
|
let prefill_throughputs: Vec<Vec<f64>> =
|
||||||
|
(0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
|
||||||
|
|
||||||
|
let decode_latencies: Vec<Vec<f64>> =
|
||||||
|
(0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
|
||||||
|
let decode_throughputs: Vec<Vec<f64>> =
|
||||||
|
(0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
|
||||||
|
|
||||||
|
let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
|
||||||
|
|
||||||
|
let decode_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
prefill_latencies,
|
||||||
|
prefill_throughputs,
|
||||||
|
decode_latencies,
|
||||||
|
decode_throughputs,
|
||||||
|
prefill_batch_latency_throughput,
|
||||||
|
decode_batch_latency_throughput,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) {
|
||||||
|
let latency = prefill.latency.as_millis() as f64;
|
||||||
|
self.prefill_latencies[batch_idx].push(latency);
|
||||||
|
self.prefill_throughputs[batch_idx].push(prefill.throughput);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_decode(&mut self, prefill: Decode, batch_idx: usize) {
|
||||||
|
let latency = prefill.latency.as_millis() as f64;
|
||||||
|
self.decode_latencies[batch_idx].push(latency);
|
||||||
|
self.decode_throughputs[batch_idx].push(prefill.throughput);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end_batch(&mut self, batch_idx: usize) {
|
||||||
|
self.prefill_batch_latency_throughput.push((
|
||||||
|
self.prefill_latencies[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.prefill_latencies[batch_idx].len() as f64,
|
||||||
|
self.prefill_throughputs[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.prefill_throughputs[batch_idx].len() as f64,
|
||||||
|
));
|
||||||
|
self.decode_batch_latency_throughput.push((
|
||||||
|
self.decode_latencies[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.decode_latencies[batch_idx].len() as f64,
|
||||||
|
self.decode_throughputs[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.decode_throughputs[batch_idx].len() as f64,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) struct UI {
|
pub(crate) struct UI {
|
||||||
pub(crate) tokenizer_name: String,
|
pub(crate) running: bool,
|
||||||
pub(crate) sequence_length: u32,
|
completed_runs: Vec<usize>,
|
||||||
pub(crate) decode_length: u32,
|
completed_batch: usize,
|
||||||
pub(crate) n_run: usize,
|
current_batch: usize,
|
||||||
pub(crate) batch_size: Vec<u32>,
|
current_tab: usize,
|
||||||
pub(crate) receiver: mpsc::Receiver<Result<Message, ClientError>>,
|
is_error: bool,
|
||||||
pub(crate) shutdown_sender: broadcast::Sender<()>,
|
data: Data,
|
||||||
pub(crate) _shutdown_guard_sender: mpsc::Sender<()>,
|
tokenizer_name: String,
|
||||||
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_run: usize,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
receiver: mpsc::Receiver<Result<Message, ClientError>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UI {
|
impl UI {
|
||||||
pub async fn draw(mut self) -> Result<(), crossterm::ErrorKind> {
|
pub(crate) fn new(
|
||||||
crossterm::terminal::enable_raw_mode()?;
|
receiver: mpsc::Receiver<Result<Message, ClientError>>,
|
||||||
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
tokenizer_name: String,
|
||||||
io::stdout().execute(crossterm::cursor::Hide)?;
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_run: usize,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
) -> Self {
|
||||||
|
let data = Data::new(n_run, batch_size.len());
|
||||||
|
let current_tab = 0;
|
||||||
|
|
||||||
let mut current_tab_idx = 0;
|
let completed_runs: Vec<usize> = (0..batch_size.len()).map(|_| 0).collect();
|
||||||
|
let completed_batch = 0;
|
||||||
|
let current_batch = 0;
|
||||||
|
let is_error = false;
|
||||||
|
|
||||||
let mut prefill_latencies: Vec<Vec<f64>> = (0..self.batch_size.len())
|
Self {
|
||||||
.map(|_| Vec::with_capacity(self.n_run))
|
running: true,
|
||||||
.collect();
|
completed_runs,
|
||||||
let mut prefill_throughputs: Vec<Vec<f64>> = (0..self.batch_size.len())
|
completed_batch,
|
||||||
.map(|_| Vec::with_capacity(self.n_run))
|
current_batch,
|
||||||
.collect();
|
current_tab,
|
||||||
|
is_error,
|
||||||
|
data,
|
||||||
|
tokenizer_name,
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
n_run,
|
||||||
|
batch_size,
|
||||||
|
receiver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut decode_latencies: Vec<Vec<f64>> = (0..self.batch_size.len())
|
pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) {
|
||||||
.map(|_| Vec::with_capacity(self.n_run))
|
match key_event {
|
||||||
.collect();
|
KeyEvent {
|
||||||
let mut decode_throughputs: Vec<Vec<f64>> = (0..self.batch_size.len())
|
code: KeyCode::Right,
|
||||||
.map(|_| Vec::with_capacity(self.n_run))
|
..
|
||||||
.collect();
|
} => {
|
||||||
|
self.current_tab = (self.current_tab + 1) % self.batch_size.len();
|
||||||
let mut prefill_batch_latency_throughput: Vec<(f64, f64)> =
|
|
||||||
Vec::with_capacity(self.batch_size.len());
|
|
||||||
|
|
||||||
let mut decode_batch_latency_throughput: Vec<(f64, f64)> =
|
|
||||||
Vec::with_capacity(self.batch_size.len());
|
|
||||||
|
|
||||||
let mut completed_runs: Vec<usize> = (0..self.batch_size.len()).map(|_| 0).collect();
|
|
||||||
let mut completed_batch = 0;
|
|
||||||
let mut current_batch_idx = 0;
|
|
||||||
let mut is_error = false;
|
|
||||||
|
|
||||||
let mut terminal = {
|
|
||||||
let backend = CrosstermBackend::new(io::stdout());
|
|
||||||
Terminal::new(backend)?
|
|
||||||
};
|
|
||||||
|
|
||||||
'outer: loop {
|
|
||||||
let frame_start = Instant::now();
|
|
||||||
loop {
|
|
||||||
match self.receiver.try_recv() {
|
|
||||||
Ok(message) => match message {
|
|
||||||
Ok(message) => {
|
|
||||||
match message {
|
|
||||||
Message::Prefill(step) => {
|
|
||||||
let latency = step.latency.as_millis() as f64;
|
|
||||||
prefill_latencies[current_batch_idx].push(latency);
|
|
||||||
prefill_throughputs[current_batch_idx].push(step.throughput);
|
|
||||||
}
|
|
||||||
Message::Decode(step) => {
|
|
||||||
let latency = step.latency.as_millis() as f64;
|
|
||||||
decode_latencies[current_batch_idx].push(latency);
|
|
||||||
decode_throughputs[current_batch_idx].push(step.throughput);
|
|
||||||
}
|
|
||||||
Message::Run(_) => {
|
|
||||||
completed_runs[current_batch_idx] += 1;
|
|
||||||
}
|
|
||||||
Message::EndBatch => {
|
|
||||||
prefill_batch_latency_throughput.push((
|
|
||||||
prefill_latencies[current_batch_idx].iter().sum::<f64>()
|
|
||||||
/ completed_runs[current_batch_idx] as f64,
|
|
||||||
prefill_throughputs[current_batch_idx].iter().sum::<f64>()
|
|
||||||
/ completed_runs[current_batch_idx] as f64,
|
|
||||||
));
|
|
||||||
decode_batch_latency_throughput.push((
|
|
||||||
decode_latencies[current_batch_idx].iter().sum::<f64>()
|
|
||||||
/ completed_runs[current_batch_idx] as f64,
|
|
||||||
decode_throughputs[current_batch_idx].iter().sum::<f64>()
|
|
||||||
/ completed_runs[current_batch_idx] as f64,
|
|
||||||
));
|
|
||||||
|
|
||||||
completed_batch += 1;
|
|
||||||
if current_batch_idx < self.batch_size.len() - 1 {
|
|
||||||
current_batch_idx += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Message::Warmup => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => is_error = true
|
|
||||||
},
|
|
||||||
Err(TryRecvError::Empty) => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(TryRecvError::Disconnected) => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
KeyEvent {
|
||||||
let batch_progress =
|
code: KeyCode::Left,
|
||||||
(completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0);
|
..
|
||||||
let run_progress =
|
} => {
|
||||||
(completed_runs[current_batch_idx] as f64 / self.n_run as f64).clamp(0.0, 1.0);
|
if self.current_tab > 0 {
|
||||||
|
self.current_tab -= 1;
|
||||||
terminal.draw(|f| {
|
|
||||||
// Vertical layout
|
|
||||||
let row5 = Layout::default()
|
|
||||||
.direction(Direction::Vertical)
|
|
||||||
.constraints(
|
|
||||||
[
|
|
||||||
Constraint::Length(1),
|
|
||||||
Constraint::Length(3),
|
|
||||||
Constraint::Length(3),
|
|
||||||
Constraint::Length(13),
|
|
||||||
Constraint::Min(10),
|
|
||||||
]
|
|
||||||
.as_ref(),
|
|
||||||
)
|
|
||||||
.split(f.size());
|
|
||||||
|
|
||||||
// Top row horizontal layout
|
|
||||||
let top = Layout::default()
|
|
||||||
.direction(Direction::Horizontal)
|
|
||||||
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
|
||||||
.split(row5[2]);
|
|
||||||
|
|
||||||
// Mid row horizontal layout
|
|
||||||
let mid = Layout::default()
|
|
||||||
.direction(Direction::Horizontal)
|
|
||||||
.constraints(
|
|
||||||
[
|
|
||||||
Constraint::Percentage(20),
|
|
||||||
Constraint::Percentage(30),
|
|
||||||
Constraint::Percentage(20),
|
|
||||||
Constraint::Percentage(30),
|
|
||||||
]
|
|
||||||
.as_ref(),
|
|
||||||
)
|
|
||||||
.split(row5[3]);
|
|
||||||
|
|
||||||
// Left mid row vertical layout
|
|
||||||
let prefill_text = Layout::default()
|
|
||||||
.direction(Direction::Vertical)
|
|
||||||
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
|
|
||||||
.split(mid[0]);
|
|
||||||
|
|
||||||
// Right mid row vertical layout
|
|
||||||
let decode_text = Layout::default()
|
|
||||||
.direction(Direction::Vertical)
|
|
||||||
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
|
|
||||||
.split(mid[2]);
|
|
||||||
|
|
||||||
// Bottom row horizontal layout
|
|
||||||
let bottom = Layout::default()
|
|
||||||
.direction(Direction::Horizontal)
|
|
||||||
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
|
||||||
.split(row5[4]);
|
|
||||||
|
|
||||||
// Title
|
|
||||||
let title = Block::default().borders(Borders::NONE).title(format!(
|
|
||||||
"Model: {} | Sequence Length: {} | Decode Length: {}",
|
|
||||||
self.tokenizer_name, self.sequence_length, self.decode_length
|
|
||||||
)).style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White));
|
|
||||||
f.render_widget(title, row5[0]);
|
|
||||||
|
|
||||||
// Batch tabs
|
|
||||||
let titles = self
|
|
||||||
.batch_size
|
|
||||||
.iter()
|
|
||||||
.map(|b| {
|
|
||||||
Spans::from(vec![Span::styled(
|
|
||||||
format!("Batch: {b}"),
|
|
||||||
Style::default().fg(Color::White),
|
|
||||||
)])
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let tabs = Tabs::new(titles)
|
|
||||||
.block(Block::default().borders(Borders::ALL).title("Tabs"))
|
|
||||||
.select(current_tab_idx)
|
|
||||||
.style(Style::default().fg(Color::LightCyan))
|
|
||||||
.highlight_style(
|
|
||||||
Style::default()
|
|
||||||
.add_modifier(Modifier::BOLD)
|
|
||||||
.bg(Color::Black),
|
|
||||||
);
|
|
||||||
f.render_widget(tabs, row5[1]);
|
|
||||||
|
|
||||||
// Total progress bar
|
|
||||||
let batch_gauge = progress_gauge(
|
|
||||||
"Total Progress",
|
|
||||||
format!("{} / {}", completed_batch, self.batch_size.len()),
|
|
||||||
batch_progress,
|
|
||||||
Color::LightGreen,
|
|
||||||
);
|
|
||||||
f.render_widget(batch_gauge, top[0]);
|
|
||||||
|
|
||||||
// Batch progress Bar
|
|
||||||
let run_gauge = progress_gauge(
|
|
||||||
"Batch Progress",
|
|
||||||
format!("{} / {}", completed_runs[current_batch_idx], self.n_run),
|
|
||||||
run_progress,
|
|
||||||
Color::LightBlue,
|
|
||||||
);
|
|
||||||
f.render_widget(run_gauge, top[1]);
|
|
||||||
|
|
||||||
// Prefill text infos
|
|
||||||
let (prefill_latency_statics, prefill_throughput_statics) = text_info(
|
|
||||||
&mut prefill_latencies[current_tab_idx],
|
|
||||||
&prefill_throughputs[current_tab_idx],
|
|
||||||
"Prefill",
|
|
||||||
);
|
|
||||||
f.render_widget(prefill_latency_statics, prefill_text[0]);
|
|
||||||
f.render_widget(prefill_throughput_statics, prefill_text[1]);
|
|
||||||
|
|
||||||
// Prefill latency histogram
|
|
||||||
let histo_width = 7;
|
|
||||||
let bins = if mid[1].width < 2 {
|
|
||||||
0
|
|
||||||
} else {
|
} else {
|
||||||
(mid[1].width as usize - 2) / (histo_width + 1)
|
self.current_tab = self.batch_size.len() - 1;
|
||||||
}
|
|
||||||
.max(2);
|
|
||||||
|
|
||||||
let histo_data = latency_histogram_data(&prefill_latencies[current_tab_idx], bins);
|
|
||||||
let histo_data_str: Vec<(&str, u64)> =
|
|
||||||
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
|
|
||||||
let prefill_histogram =
|
|
||||||
latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16);
|
|
||||||
f.render_widget(prefill_histogram, mid[1]);
|
|
||||||
|
|
||||||
// Decode text info
|
|
||||||
let (decode_latency_statics, decode_throughput_statics) = text_info(
|
|
||||||
&mut decode_latencies[current_tab_idx],
|
|
||||||
&decode_throughputs[current_tab_idx],
|
|
||||||
"Decode",
|
|
||||||
);
|
|
||||||
f.render_widget(decode_latency_statics, decode_text[0]);
|
|
||||||
f.render_widget(decode_throughput_statics, decode_text[1]);
|
|
||||||
|
|
||||||
// Decode latency histogram
|
|
||||||
let histo_data = latency_histogram_data(&decode_latencies[current_tab_idx], bins);
|
|
||||||
let histo_data_str: Vec<(&str, u64)> =
|
|
||||||
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
|
|
||||||
let decode_histogram =
|
|
||||||
latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16);
|
|
||||||
f.render_widget(decode_histogram, mid[3]);
|
|
||||||
|
|
||||||
// Prefill latency/throughput chart
|
|
||||||
let prefill_latency_throughput_chart = latency_throughput_chart(
|
|
||||||
&prefill_batch_latency_throughput,
|
|
||||||
&self.batch_size,
|
|
||||||
"Prefill",
|
|
||||||
);
|
|
||||||
f.render_widget(prefill_latency_throughput_chart, bottom[0]);
|
|
||||||
|
|
||||||
// Decode latency/throughput chart
|
|
||||||
let decode_latency_throughput_chart = latency_throughput_chart(
|
|
||||||
&decode_batch_latency_throughput,
|
|
||||||
&self.batch_size,
|
|
||||||
"Decode",
|
|
||||||
);
|
|
||||||
f.render_widget(decode_latency_throughput_chart, bottom[1]);
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Quit on q or CTRL+c
|
|
||||||
|
|
||||||
while event::poll(Duration::from_millis(100))? {
|
|
||||||
if let Event::Key(key) = event::read()? {
|
|
||||||
match key {
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Right,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
current_tab_idx = (current_tab_idx + 1) % self.batch_size.len();
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Left,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
if current_tab_idx > 0 {
|
|
||||||
current_tab_idx -= 1;
|
|
||||||
} else {
|
|
||||||
current_tab_idx = self.batch_size.len() - 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Char('q'),
|
|
||||||
..
|
|
||||||
}
|
|
||||||
| KeyEvent {
|
|
||||||
code: KeyCode::Char('c'),
|
|
||||||
modifiers: KeyModifiers::CONTROL,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
break 'outer;
|
|
||||||
}
|
|
||||||
_ => (),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Char('q'),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| KeyEvent {
|
||||||
|
code: KeyCode::Char('c'),
|
||||||
|
modifiers: KeyModifiers::CONTROL,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.running = false;
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Frame budget
|
pub(crate) fn tick(&mut self) {
|
||||||
let per_frame = Duration::from_secs(1) / 30 as u32;
|
while let Ok(message) = self.receiver.try_recv() {
|
||||||
let elapsed = frame_start.elapsed();
|
match message {
|
||||||
if per_frame > elapsed {
|
Ok(message) => match message {
|
||||||
sleep(per_frame - elapsed).await;
|
Message::Prefill(step) => self.data.push_prefill(step, self.current_batch),
|
||||||
|
Message::Decode(step) => self.data.push_decode(step, self.current_batch),
|
||||||
|
Message::Run(_) => {
|
||||||
|
self.completed_runs[self.current_batch] += 1;
|
||||||
|
}
|
||||||
|
Message::EndBatch => {
|
||||||
|
self.data.end_batch(self.current_batch);
|
||||||
|
|
||||||
|
self.completed_batch += 1;
|
||||||
|
if self.current_batch < self.batch_size.len() - 1 {
|
||||||
|
self.current_batch += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Warmup => {}
|
||||||
|
},
|
||||||
|
Err(_) => self.is_error = true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Revert terminal to original view
|
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
|
||||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
let batch_progress =
|
||||||
crossterm::terminal::disable_raw_mode()?;
|
(self.completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0);
|
||||||
io::stdout().execute(crossterm::cursor::Show)?;
|
let run_progress =
|
||||||
|
(self.completed_runs[self.current_batch] as f64 / self.n_run as f64).clamp(0.0, 1.0);
|
||||||
|
|
||||||
let _ = self.shutdown_sender.send(());
|
// Vertical layout
|
||||||
Ok(())
|
let row5 = Layout::default()
|
||||||
|
.direction(Direction::Vertical)
|
||||||
|
.constraints(
|
||||||
|
[
|
||||||
|
Constraint::Length(1),
|
||||||
|
Constraint::Length(3),
|
||||||
|
Constraint::Length(3),
|
||||||
|
Constraint::Length(13),
|
||||||
|
Constraint::Min(10),
|
||||||
|
]
|
||||||
|
.as_ref(),
|
||||||
|
)
|
||||||
|
.split(f.size());
|
||||||
|
|
||||||
|
// Top row horizontal layout
|
||||||
|
let top = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
|
.split(row5[2]);
|
||||||
|
|
||||||
|
// Mid row horizontal layout
|
||||||
|
let mid = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints(
|
||||||
|
[
|
||||||
|
Constraint::Percentage(20),
|
||||||
|
Constraint::Percentage(30),
|
||||||
|
Constraint::Percentage(20),
|
||||||
|
Constraint::Percentage(30),
|
||||||
|
]
|
||||||
|
.as_ref(),
|
||||||
|
)
|
||||||
|
.split(row5[3]);
|
||||||
|
|
||||||
|
// Left mid row vertical layout
|
||||||
|
let prefill_text = Layout::default()
|
||||||
|
.direction(Direction::Vertical)
|
||||||
|
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
|
||||||
|
.split(mid[0]);
|
||||||
|
|
||||||
|
// Right mid row vertical layout
|
||||||
|
let decode_text = Layout::default()
|
||||||
|
.direction(Direction::Vertical)
|
||||||
|
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
|
||||||
|
.split(mid[2]);
|
||||||
|
|
||||||
|
// Bottom row horizontal layout
|
||||||
|
let bottom = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
|
.split(row5[4]);
|
||||||
|
|
||||||
|
// Title
|
||||||
|
let title = Block::default()
|
||||||
|
.borders(Borders::NONE)
|
||||||
|
.title(format!(
|
||||||
|
"Model: {} | Sequence Length: {} | Decode Length: {}",
|
||||||
|
self.tokenizer_name, self.sequence_length, self.decode_length
|
||||||
|
))
|
||||||
|
.style(
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.fg(Color::White),
|
||||||
|
);
|
||||||
|
f.render_widget(title, row5[0]);
|
||||||
|
|
||||||
|
// Batch tabs
|
||||||
|
let titles = self
|
||||||
|
.batch_size
|
||||||
|
.iter()
|
||||||
|
.map(|b| {
|
||||||
|
Spans::from(vec![Span::styled(
|
||||||
|
format!("Batch: {b}"),
|
||||||
|
Style::default().fg(Color::White),
|
||||||
|
)])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let tabs = Tabs::new(titles)
|
||||||
|
.block(Block::default().borders(Borders::ALL).title("Tabs"))
|
||||||
|
.select(self.current_tab)
|
||||||
|
.style(Style::default().fg(Color::LightCyan))
|
||||||
|
.highlight_style(
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.bg(Color::Black),
|
||||||
|
);
|
||||||
|
f.render_widget(tabs, row5[1]);
|
||||||
|
|
||||||
|
// Total progress bar
|
||||||
|
let batch_gauge = progress_gauge(
|
||||||
|
"Total Progress",
|
||||||
|
format!("{} / {}", self.completed_batch, self.batch_size.len()),
|
||||||
|
batch_progress,
|
||||||
|
Color::LightGreen,
|
||||||
|
);
|
||||||
|
f.render_widget(batch_gauge, top[0]);
|
||||||
|
|
||||||
|
// Batch progress Bar
|
||||||
|
let run_gauge = progress_gauge(
|
||||||
|
"Batch Progress",
|
||||||
|
format!(
|
||||||
|
"{} / {}",
|
||||||
|
self.completed_runs[self.current_batch], self.n_run
|
||||||
|
),
|
||||||
|
run_progress,
|
||||||
|
Color::LightBlue,
|
||||||
|
);
|
||||||
|
f.render_widget(run_gauge, top[1]);
|
||||||
|
|
||||||
|
// Prefill text infos
|
||||||
|
let (prefill_latency_statics, prefill_throughput_statics) = text_info(
|
||||||
|
&mut self.data.prefill_latencies[self.current_tab],
|
||||||
|
&self.data.prefill_throughputs[self.current_tab],
|
||||||
|
"Prefill",
|
||||||
|
);
|
||||||
|
f.render_widget(prefill_latency_statics, prefill_text[0]);
|
||||||
|
f.render_widget(prefill_throughput_statics, prefill_text[1]);
|
||||||
|
|
||||||
|
// Prefill latency histogram
|
||||||
|
let histo_width = 7;
|
||||||
|
let bins = if mid[1].width < 2 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
(mid[1].width as usize - 2) / (histo_width + 1)
|
||||||
|
}
|
||||||
|
.max(2);
|
||||||
|
|
||||||
|
let histo_data =
|
||||||
|
latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins);
|
||||||
|
let histo_data_str: Vec<(&str, u64)> =
|
||||||
|
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
|
||||||
|
let prefill_histogram =
|
||||||
|
latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16);
|
||||||
|
f.render_widget(prefill_histogram, mid[1]);
|
||||||
|
|
||||||
|
// Decode text info
|
||||||
|
let (decode_latency_statics, decode_throughput_statics) = text_info(
|
||||||
|
&mut self.data.decode_latencies[self.current_tab],
|
||||||
|
&self.data.decode_throughputs[self.current_tab],
|
||||||
|
"Decode",
|
||||||
|
);
|
||||||
|
f.render_widget(decode_latency_statics, decode_text[0]);
|
||||||
|
f.render_widget(decode_throughput_statics, decode_text[1]);
|
||||||
|
|
||||||
|
// Decode latency histogram
|
||||||
|
let histo_data =
|
||||||
|
latency_histogram_data(&self.data.decode_latencies[self.current_tab], bins);
|
||||||
|
let histo_data_str: Vec<(&str, u64)> =
|
||||||
|
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
|
||||||
|
let decode_histogram =
|
||||||
|
latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16);
|
||||||
|
f.render_widget(decode_histogram, mid[3]);
|
||||||
|
|
||||||
|
// Prefill latency/throughput chart
|
||||||
|
let prefill_latency_throughput_chart = latency_throughput_chart(
|
||||||
|
&self.data.prefill_batch_latency_throughput,
|
||||||
|
&self.batch_size,
|
||||||
|
"Prefill",
|
||||||
|
);
|
||||||
|
f.render_widget(prefill_latency_throughput_chart, bottom[0]);
|
||||||
|
|
||||||
|
// Decode latency/throughput chart
|
||||||
|
let decode_latency_throughput_chart = latency_throughput_chart(
|
||||||
|
&self.data.decode_batch_latency_throughput,
|
||||||
|
&self.batch_size,
|
||||||
|
"Decode",
|
||||||
|
);
|
||||||
|
f.render_widget(decode_latency_throughput_chart, bottom[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user