mod app; mod event; mod generation; mod table; mod utils; use crate::app::App; use crate::event::Event; use crossterm::ExecutableCommand; use std::io; use text_generation_client::{NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; use tui::Terminal; /// Run benchmarking app #[allow(clippy::too_many_arguments)] pub async fn run( tokenizer_name: String, tokenizer: Tokenizer, batch_size: Vec, sequence_length: u32, decode_length: u32, n_runs: usize, warmups: usize, temperature: Option, top_k: Option, top_p: Option, typical_p: Option, repetition_penalty: Option, watermark: bool, do_sample: bool, client: ShardedClient, ) -> Result<(), crossterm::ErrorKind> { let parameters = NextTokenChooserParameters { temperature: temperature.unwrap_or(1.0), top_k: top_k.unwrap_or(0), top_p: top_p.unwrap_or(1.0), typical_p: typical_p.unwrap_or(1.0), do_sample, seed: 0, repetition_penalty: repetition_penalty.unwrap_or(1.0), watermark, }; // Initialize terminal properties crossterm::terminal::enable_raw_mode()?; io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; io::stdout().execute(crossterm::cursor::Hide)?; // Initialize terminal let mut terminal = { let backend = CrosstermBackend::new(io::stdout()); Terminal::new(backend)? }; // Create message channel between generation_task and app let (run_sender, run_receiver) = mpsc::channel(8); // Crossterm event channel let (event_sender, mut event_receiver) = mpsc::channel(8); // Shutdown channel to terminate tasks let (shutdown_sender, _) = broadcast::channel(1); // Channel to check if tasks terminated let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1); // Create generation task tokio::spawn(generation::generation_task( tokenizer, batch_size.clone(), sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender, shutdown_sender.subscribe(), shutdown_guard_sender.clone(), )); // Create event task tokio::spawn(event::terminal_event_task( 250, event_sender, shutdown_sender.subscribe(), shutdown_guard_sender.clone(), )); // Drop our end of shutdown sender drop(shutdown_guard_sender); // Create App let mut app = App::new( run_receiver, tokenizer_name.clone(), sequence_length, decode_length, n_runs, batch_size, ); while app.running { // Draw frame terminal.draw(|frame| app.render(frame))?; // Await a new event from event handling task match event_receiver.recv().await { None => break, // Update app state Some(event) => match event { Event::Tick => app.tick(), Event::Key(key_event) => app.handle_key_event(key_event), _ => {} }, } } // Ask tasks to shutdown let _ = shutdown_sender.send(()); // Wait for tasks to shutdown let _ = shutdown_guard_receiver.recv().await; // Revert terminal to original view io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; crossterm::terminal::disable_raw_mode()?; io::stdout().execute(crossterm::cursor::Show)?; let parameters_table = table::parameters_table( tokenizer_name, sequence_length, decode_length, n_runs, warmups, temperature, top_k, top_p, typical_p, repetition_penalty, watermark, do_sample, ); println!("\n{parameters_table}\n"); let latency_table = table::latency_table(&app.data); println!("\n{latency_table}\n"); let throughput_table = table::throughput_table(&app.data); println!("\n{throughput_table}\n"); Ok(()) }