From 681744b98249aa1a089990aebd6ece04e9cfa431 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 28 Mar 2023 11:13:14 +0200 Subject: [PATCH] add shutdown logic --- benchmark/src/lib.rs | 49 +++++++++++++++++++++++++++++-------------- benchmark/src/main.rs | 2 +- benchmark/src/ui.rs | 25 +++++++++++++++++++--- 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 52c0c4bc..110e1f6a 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -4,18 +4,18 @@ use std::time::{Duration, Instant}; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::time; use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, broadcast}; use crate::ui::UI; const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum Step { Prefill, Decode, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct Run { step: Step, batch_size: u32, @@ -29,39 +29,56 @@ pub async fn run( batch_size: Vec, sequence_length: u32, decode_length: u32, - runs: usize, + n_runs: usize, mut client: ShardedClient, ) -> Result<(), Box> { - let (sender, receiver) = mpsc::channel(8); + let (run_sender, run_receiver) = mpsc::channel(8); + let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1); tokio::spawn( UI { - n_run: runs, + n_run: n_runs, n_batch: batch_size.len(), n_batch_done: 0, - run_receiver: receiver, + run_receiver, + shutdown_sender, }.draw() ); + let mut runs = Vec::with_capacity(batch_size.len() * n_runs); let sequence = create_sequence(sequence_length, tokenizer); - for b in batch_size { - for _ in 0..runs { - let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?; - sender.send(run).await.unwrap(); + for _ in 0..n_runs { + let (run, decode_batch) = tokio::select! { + res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?, + _ = shutdown_receiver.recv() => { + tracing::info!("shutdown"); + return Ok(()); + } + }; + run_sender.send(run.clone()).await.unwrap(); + runs.push(run); - let run = run_decode(decode_batch, sequence_length, &mut client).await?; - sender.send(run).await.unwrap(); + let run = tokio::select! { + res = run_decode(decode_batch, sequence_length, &mut client) => res?, + _ = shutdown_receiver.recv() => { + tracing::info!("shutdown"); + return Ok(()); + } + }; - tokio::time::sleep(Duration::from_millis(100)).await; + run_sender.send(run.clone()).await.unwrap(); + runs.push(run); } } - drop(sender); + // Shutdown UI by dropping run sender triggering the UI to exit its rendering loop + drop(run_sender); - tokio::time::sleep(Duration::from_millis(100)).await; + // Wait for UI to shutdown + let _ = shutdown_receiver.recv().await; Ok(()) } diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index fff8ad82..4ec96586 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -19,7 +19,7 @@ struct Args { sequence_length: u32, #[clap(default_value = "100", long, env)] decode_length: u32, - #[clap(default_value = "10", long, env)] + #[clap(default_value = "2", long, env)] runs: usize, #[clap(default_value = "/tmp/text-generation-0", long, env)] master_shard_uds_path: String, diff --git a/benchmark/src/ui.rs b/benchmark/src/ui.rs index e775520c..5c0915d8 100644 --- a/benchmark/src/ui.rs +++ b/benchmark/src/ui.rs @@ -13,14 +13,15 @@ use tui::style::{Color, Style}; use tui::text::{Span, Spans}; use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph}; use tui::Terminal; -use tokio::sync::mpsc::Receiver; +use tokio::sync::{mpsc, broadcast}; use crate::{Run, Step}; pub(crate) struct UI { pub(crate) n_run: usize, pub(crate) n_batch: usize, pub(crate) n_batch_done: usize, - pub(crate) run_receiver: Receiver, + pub(crate) run_receiver: mpsc::Receiver, + pub(crate) shutdown_sender: broadcast::Sender<()>, } impl UI { @@ -175,6 +176,23 @@ impl UI { f.render_widget(decode_throughput_statics, decode_text[1]); })?; + while crossterm::event::poll(Duration::from_secs(0))? { + match crossterm::event::read()? { + Event::Key(KeyEvent { + code: KeyCode::Char('q'), + .. + }) + | Event::Key(KeyEvent { + code: KeyCode::Char('c'), + modifiers: KeyModifiers::CONTROL, + .. + }) => { + break 'outer; + } + _ => (), + } + } + let per_frame = Duration::from_secs(1) / 30 as u32; let elapsed = frame_start.elapsed(); if per_frame > elapsed { @@ -182,10 +200,11 @@ impl UI { } } - io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; crossterm::terminal::disable_raw_mode()?; io::stdout().execute(crossterm::cursor::Show)?; + + let _ = self.shutdown_sender.send(()); Ok(()) } }