add shutdown logic

This commit is contained in:
OlivierDehaene 2023-03-28 11:13:14 +02:00
parent c0d793d2ca
commit 681744b982
3 changed files with 56 additions and 20 deletions

View File

@ -4,18 +4,18 @@ use std::time::{Duration, Instant};
use tokenizers::{Tokenizer, TruncationDirection}; use tokenizers::{Tokenizer, TruncationDirection};
use tokio::time; use tokio::time;
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError}; use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
use tokio::sync::mpsc; use tokio::sync::{mpsc, broadcast};
use crate::ui::UI; use crate::ui::UI;
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
#[derive(Debug)] #[derive(Debug, Clone)]
pub(crate) enum Step { pub(crate) enum Step {
Prefill, Prefill,
Decode, Decode,
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub(crate) struct Run { pub(crate) struct Run {
step: Step, step: Step,
batch_size: u32, batch_size: u32,
@ -29,39 +29,56 @@ pub async fn run(
batch_size: Vec<u32>, batch_size: Vec<u32>,
sequence_length: u32, sequence_length: u32,
decode_length: u32, decode_length: u32,
runs: usize, n_runs: usize,
mut client: ShardedClient, mut client: ShardedClient,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let (sender, receiver) = mpsc::channel(8); let (run_sender, run_receiver) = mpsc::channel(8);
let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);
tokio::spawn( tokio::spawn(
UI { UI {
n_run: runs, n_run: n_runs,
n_batch: batch_size.len(), n_batch: batch_size.len(),
n_batch_done: 0, n_batch_done: 0,
run_receiver: receiver, run_receiver,
shutdown_sender,
}.draw() }.draw()
); );
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
let sequence = create_sequence(sequence_length, tokenizer); let sequence = create_sequence(sequence_length, tokenizer);
for b in batch_size { for b in batch_size {
for _ in 0..runs { for _ in 0..n_runs {
let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?; let (run, decode_batch) = tokio::select! {
sender.send(run).await.unwrap(); res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
_ = shutdown_receiver.recv() => {
tracing::info!("shutdown");
return Ok(());
}
};
run_sender.send(run.clone()).await.unwrap();
runs.push(run);
let run = run_decode(decode_batch, sequence_length, &mut client).await?; let run = tokio::select! {
sender.send(run).await.unwrap(); res = run_decode(decode_batch, sequence_length, &mut client) => res?,
_ = shutdown_receiver.recv() => {
tracing::info!("shutdown");
return Ok(());
}
};
tokio::time::sleep(Duration::from_millis(100)).await; run_sender.send(run.clone()).await.unwrap();
runs.push(run);
} }
} }
drop(sender); // Shutdown UI by dropping run sender triggering the UI to exit its rendering loop
drop(run_sender);
tokio::time::sleep(Duration::from_millis(100)).await; // Wait for UI to shutdown
let _ = shutdown_receiver.recv().await;
Ok(()) Ok(())
} }

View File

@ -19,7 +19,7 @@ struct Args {
sequence_length: u32, sequence_length: u32,
#[clap(default_value = "100", long, env)] #[clap(default_value = "100", long, env)]
decode_length: u32, decode_length: u32,
#[clap(default_value = "10", long, env)] #[clap(default_value = "2", long, env)]
runs: usize, runs: usize,
#[clap(default_value = "/tmp/text-generation-0", long, env)] #[clap(default_value = "/tmp/text-generation-0", long, env)]
master_shard_uds_path: String, master_shard_uds_path: String,

View File

@ -13,14 +13,15 @@ use tui::style::{Color, Style};
use tui::text::{Span, Spans}; use tui::text::{Span, Spans};
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph}; use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
use tui::Terminal; use tui::Terminal;
use tokio::sync::mpsc::Receiver; use tokio::sync::{mpsc, broadcast};
use crate::{Run, Step}; use crate::{Run, Step};
pub(crate) struct UI { pub(crate) struct UI {
pub(crate) n_run: usize, pub(crate) n_run: usize,
pub(crate) n_batch: usize, pub(crate) n_batch: usize,
pub(crate) n_batch_done: usize, pub(crate) n_batch_done: usize,
pub(crate) run_receiver: Receiver<Run>, pub(crate) run_receiver: mpsc::Receiver<Run>,
pub(crate) shutdown_sender: broadcast::Sender<()>,
} }
impl UI { impl UI {
@ -175,6 +176,23 @@ impl UI {
f.render_widget(decode_throughput_statics, decode_text[1]); f.render_widget(decode_throughput_statics, decode_text[1]);
})?; })?;
while crossterm::event::poll(Duration::from_secs(0))? {
match crossterm::event::read()? {
Event::Key(KeyEvent {
code: KeyCode::Char('q'),
..
})
| Event::Key(KeyEvent {
code: KeyCode::Char('c'),
modifiers: KeyModifiers::CONTROL,
..
}) => {
break 'outer;
}
_ => (),
}
}
let per_frame = Duration::from_secs(1) / 30 as u32; let per_frame = Duration::from_secs(1) / 30 as u32;
let elapsed = frame_start.elapsed(); let elapsed = frame_start.elapsed();
if per_frame > elapsed { if per_frame > elapsed {
@ -182,10 +200,11 @@ impl UI {
} }
} }
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?; crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?; io::stdout().execute(crossterm::cursor::Show)?;
let _ = self.shutdown_sender.send(());
Ok(()) Ok(())
} }
} }