2023-03-30 13:26:27 +00:00
|
|
|
mod app;
|
|
|
|
mod event;
|
|
|
|
mod generation;
|
2023-05-25 11:38:36 +00:00
|
|
|
mod table;
|
2023-03-30 13:26:27 +00:00
|
|
|
mod utils;
|
|
|
|
|
|
|
|
use crate::app::App;
|
|
|
|
use crate::event::Event;
|
|
|
|
use crossterm::ExecutableCommand;
|
|
|
|
use std::io;
|
2024-06-04 13:56:56 +00:00
|
|
|
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
2023-03-30 13:26:27 +00:00
|
|
|
use tokenizers::Tokenizer;
|
|
|
|
use tokio::sync::{broadcast, mpsc};
|
|
|
|
use tui::backend::CrosstermBackend;
|
|
|
|
use tui::Terminal;
|
|
|
|
|
|
|
|
/// Run benchmarking app
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
|
|
pub async fn run(
|
|
|
|
tokenizer_name: String,
|
|
|
|
tokenizer: Tokenizer,
|
|
|
|
batch_size: Vec<u32>,
|
|
|
|
sequence_length: u32,
|
|
|
|
decode_length: u32,
|
2023-08-28 09:43:47 +00:00
|
|
|
top_n_tokens: Option<u32>,
|
2023-03-30 13:26:27 +00:00
|
|
|
n_runs: usize,
|
|
|
|
warmups: usize,
|
2023-05-25 11:38:36 +00:00
|
|
|
temperature: Option<f32>,
|
|
|
|
top_k: Option<u32>,
|
|
|
|
top_p: Option<f32>,
|
|
|
|
typical_p: Option<f32>,
|
|
|
|
repetition_penalty: Option<f32>,
|
2024-02-08 17:41:25 +00:00
|
|
|
frequency_penalty: Option<f32>,
|
2023-05-25 11:38:36 +00:00
|
|
|
watermark: bool,
|
|
|
|
do_sample: bool,
|
2023-03-30 13:26:27 +00:00
|
|
|
client: ShardedClient,
|
2023-09-27 08:40:18 +00:00
|
|
|
) -> Result<(), std::io::Error> {
|
2023-05-25 11:38:36 +00:00
|
|
|
let parameters = NextTokenChooserParameters {
|
|
|
|
temperature: temperature.unwrap_or(1.0),
|
|
|
|
top_k: top_k.unwrap_or(0),
|
|
|
|
top_p: top_p.unwrap_or(1.0),
|
|
|
|
typical_p: typical_p.unwrap_or(1.0),
|
|
|
|
do_sample,
|
|
|
|
seed: 0,
|
|
|
|
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
2024-02-08 17:41:25 +00:00
|
|
|
frequency_penalty: frequency_penalty.unwrap_or(0.0),
|
2023-05-25 11:38:36 +00:00
|
|
|
watermark,
|
2024-02-15 09:28:10 +00:00
|
|
|
grammar: String::new(),
|
|
|
|
grammar_type: GrammarType::None as i32,
|
2023-05-25 11:38:36 +00:00
|
|
|
};
|
|
|
|
|
2023-03-30 13:26:27 +00:00
|
|
|
// Initialize terminal properties
|
|
|
|
crossterm::terminal::enable_raw_mode()?;
|
|
|
|
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
|
|
|
io::stdout().execute(crossterm::cursor::Hide)?;
|
|
|
|
|
|
|
|
// Initialize terminal
|
|
|
|
let mut terminal = {
|
|
|
|
let backend = CrosstermBackend::new(io::stdout());
|
|
|
|
Terminal::new(backend)?
|
|
|
|
};
|
|
|
|
|
|
|
|
// Create message channel between generation_task and app
|
|
|
|
let (run_sender, run_receiver) = mpsc::channel(8);
|
|
|
|
// Crossterm event channel
|
|
|
|
let (event_sender, mut event_receiver) = mpsc::channel(8);
|
|
|
|
// Shutdown channel to terminate tasks
|
|
|
|
let (shutdown_sender, _) = broadcast::channel(1);
|
|
|
|
// Channel to check if tasks terminated
|
|
|
|
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
|
|
|
|
|
|
|
|
// Create generation task
|
|
|
|
tokio::spawn(generation::generation_task(
|
|
|
|
tokenizer,
|
|
|
|
batch_size.clone(),
|
|
|
|
sequence_length,
|
|
|
|
decode_length,
|
2023-08-28 09:43:47 +00:00
|
|
|
top_n_tokens,
|
2023-03-30 13:26:27 +00:00
|
|
|
n_runs,
|
|
|
|
warmups,
|
2023-05-25 11:38:36 +00:00
|
|
|
parameters,
|
2023-03-30 13:26:27 +00:00
|
|
|
client,
|
|
|
|
run_sender,
|
|
|
|
shutdown_sender.subscribe(),
|
|
|
|
shutdown_guard_sender.clone(),
|
|
|
|
));
|
|
|
|
|
|
|
|
// Create event task
|
|
|
|
tokio::spawn(event::terminal_event_task(
|
|
|
|
250,
|
|
|
|
event_sender,
|
|
|
|
shutdown_sender.subscribe(),
|
|
|
|
shutdown_guard_sender.clone(),
|
|
|
|
));
|
|
|
|
|
|
|
|
// Drop our end of shutdown sender
|
|
|
|
drop(shutdown_guard_sender);
|
|
|
|
|
|
|
|
// Create App
|
|
|
|
let mut app = App::new(
|
|
|
|
run_receiver,
|
2023-05-25 11:38:36 +00:00
|
|
|
tokenizer_name.clone(),
|
2023-03-30 13:26:27 +00:00
|
|
|
sequence_length,
|
|
|
|
decode_length,
|
|
|
|
n_runs,
|
|
|
|
batch_size,
|
|
|
|
);
|
|
|
|
|
|
|
|
while app.running {
|
|
|
|
// Draw frame
|
|
|
|
terminal.draw(|frame| app.render(frame))?;
|
|
|
|
|
|
|
|
// Await a new event from event handling task
|
|
|
|
match event_receiver.recv().await {
|
|
|
|
None => break,
|
|
|
|
// Update app state
|
|
|
|
Some(event) => match event {
|
|
|
|
Event::Tick => app.tick(),
|
|
|
|
Event::Key(key_event) => app.handle_key_event(key_event),
|
|
|
|
_ => {}
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ask tasks to shutdown
|
|
|
|
let _ = shutdown_sender.send(());
|
|
|
|
// Wait for tasks to shutdown
|
|
|
|
let _ = shutdown_guard_receiver.recv().await;
|
|
|
|
|
|
|
|
// Revert terminal to original view
|
|
|
|
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
|
|
|
crossterm::terminal::disable_raw_mode()?;
|
|
|
|
io::stdout().execute(crossterm::cursor::Show)?;
|
|
|
|
|
2023-05-25 11:38:36 +00:00
|
|
|
let parameters_table = table::parameters_table(
|
|
|
|
tokenizer_name,
|
|
|
|
sequence_length,
|
|
|
|
decode_length,
|
2023-08-28 09:43:47 +00:00
|
|
|
top_n_tokens,
|
2023-05-25 11:38:36 +00:00
|
|
|
n_runs,
|
|
|
|
warmups,
|
|
|
|
temperature,
|
|
|
|
top_k,
|
|
|
|
top_p,
|
|
|
|
typical_p,
|
|
|
|
repetition_penalty,
|
2024-02-08 17:41:25 +00:00
|
|
|
frequency_penalty,
|
2023-05-25 11:38:36 +00:00
|
|
|
watermark,
|
|
|
|
do_sample,
|
|
|
|
);
|
|
|
|
println!("\n{parameters_table}\n");
|
|
|
|
|
|
|
|
let latency_table = table::latency_table(&app.data);
|
|
|
|
println!("\n{latency_table}\n");
|
|
|
|
|
|
|
|
let throughput_table = table::throughput_table(&app.data);
|
|
|
|
println!("\n{throughput_table}\n");
|
|
|
|
|
2023-03-30 13:26:27 +00:00
|
|
|
Ok(())
|
|
|
|
}
|