add shutdown logic

This commit is contained in:
OlivierDehaene 2023-03-28 11:13:14 +02:00
parent c0d793d2ca
commit 681744b982
3 changed files with 56 additions and 20 deletions

View File

@ -4,18 +4,18 @@ use std::time::{Duration, Instant};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::time;
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
use tokio::sync::mpsc;
use tokio::sync::{mpsc, broadcast};
use crate::ui::UI;
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) enum Step {
Prefill,
Decode,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct Run {
step: Step,
batch_size: u32,
@ -29,39 +29,56 @@ pub async fn run(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
runs: usize,
n_runs: usize,
mut client: ShardedClient,
) -> Result<(), Box<dyn std::error::Error>> {
let (sender, receiver) = mpsc::channel(8);
let (run_sender, run_receiver) = mpsc::channel(8);
let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);
tokio::spawn(
UI {
n_run: runs,
n_run: n_runs,
n_batch: batch_size.len(),
n_batch_done: 0,
run_receiver: receiver,
run_receiver,
shutdown_sender,
}.draw()
);
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
let sequence = create_sequence(sequence_length, tokenizer);
for b in batch_size {
for _ in 0..runs {
let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
sender.send(run).await.unwrap();
for _ in 0..n_runs {
let (run, decode_batch) = tokio::select! {
res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
_ = shutdown_receiver.recv() => {
tracing::info!("shutdown");
return Ok(());
}
};
run_sender.send(run.clone()).await.unwrap();
runs.push(run);
let run = run_decode(decode_batch, sequence_length, &mut client).await?;
sender.send(run).await.unwrap();
let run = tokio::select! {
res = run_decode(decode_batch, sequence_length, &mut client) => res?,
_ = shutdown_receiver.recv() => {
tracing::info!("shutdown");
return Ok(());
}
};
tokio::time::sleep(Duration::from_millis(100)).await;
run_sender.send(run.clone()).await.unwrap();
runs.push(run);
}
}
drop(sender);
// Shutdown UI by dropping run sender triggering the UI to exit its rendering loop
drop(run_sender);
tokio::time::sleep(Duration::from_millis(100)).await;
// Wait for UI to shutdown
let _ = shutdown_receiver.recv().await;
Ok(())
}

View File

@ -19,7 +19,7 @@ struct Args {
sequence_length: u32,
#[clap(default_value = "100", long, env)]
decode_length: u32,
#[clap(default_value = "10", long, env)]
#[clap(default_value = "2", long, env)]
runs: usize,
#[clap(default_value = "/tmp/text-generation-0", long, env)]
master_shard_uds_path: String,

View File

@ -13,14 +13,15 @@ use tui::style::{Color, Style};
use tui::text::{Span, Spans};
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
use tui::Terminal;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{mpsc, broadcast};
use crate::{Run, Step};
pub(crate) struct UI {
pub(crate) n_run: usize,
pub(crate) n_batch: usize,
pub(crate) n_batch_done: usize,
pub(crate) run_receiver: Receiver<Run>,
pub(crate) run_receiver: mpsc::Receiver<Run>,
pub(crate) shutdown_sender: broadcast::Sender<()>,
}
impl UI {
@ -175,6 +176,23 @@ impl UI {
f.render_widget(decode_throughput_statics, decode_text[1]);
})?;
while crossterm::event::poll(Duration::from_secs(0))? {
match crossterm::event::read()? {
Event::Key(KeyEvent {
code: KeyCode::Char('q'),
..
})
| Event::Key(KeyEvent {
code: KeyCode::Char('c'),
modifiers: KeyModifiers::CONTROL,
..
}) => {
break 'outer;
}
_ => (),
}
}
let per_frame = Duration::from_secs(1) / 30 as u32;
let elapsed = frame_start.elapsed();
if per_frame > elapsed {
@ -182,10 +200,11 @@ impl UI {
}
}
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
let _ = self.shutdown_sender.send(());
Ok(())
}
}