mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
add shutdown logic
This commit is contained in:
parent
c0d793d2ca
commit
681744b982
@ -4,18 +4,18 @@ use std::time::{Duration, Instant};
|
||||
use tokenizers::{Tokenizer, TruncationDirection};
|
||||
use tokio::time;
|
||||
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, broadcast};
|
||||
use crate::ui::UI;
|
||||
|
||||
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum Step {
|
||||
Prefill,
|
||||
Decode,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Run {
|
||||
step: Step,
|
||||
batch_size: u32,
|
||||
@ -29,39 +29,56 @@ pub async fn run(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
runs: usize,
|
||||
n_runs: usize,
|
||||
mut client: ShardedClient,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (sender, receiver) = mpsc::channel(8);
|
||||
let (run_sender, run_receiver) = mpsc::channel(8);
|
||||
let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);
|
||||
|
||||
|
||||
tokio::spawn(
|
||||
UI {
|
||||
n_run: runs,
|
||||
n_run: n_runs,
|
||||
n_batch: batch_size.len(),
|
||||
n_batch_done: 0,
|
||||
run_receiver: receiver,
|
||||
run_receiver,
|
||||
shutdown_sender,
|
||||
}.draw()
|
||||
);
|
||||
|
||||
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
|
||||
let sequence = create_sequence(sequence_length, tokenizer);
|
||||
|
||||
|
||||
for b in batch_size {
|
||||
for _ in 0..runs {
|
||||
let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
|
||||
sender.send(run).await.unwrap();
|
||||
for _ in 0..n_runs {
|
||||
let (run, decode_batch) = tokio::select! {
|
||||
res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
|
||||
_ = shutdown_receiver.recv() => {
|
||||
tracing::info!("shutdown");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
run_sender.send(run.clone()).await.unwrap();
|
||||
runs.push(run);
|
||||
|
||||
let run = run_decode(decode_batch, sequence_length, &mut client).await?;
|
||||
sender.send(run).await.unwrap();
|
||||
let run = tokio::select! {
|
||||
res = run_decode(decode_batch, sequence_length, &mut client) => res?,
|
||||
_ = shutdown_receiver.recv() => {
|
||||
tracing::info!("shutdown");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
run_sender.send(run.clone()).await.unwrap();
|
||||
runs.push(run);
|
||||
}
|
||||
}
|
||||
|
||||
drop(sender);
|
||||
// Shutdown UI by dropping run sender triggering the UI to exit its rendering loop
|
||||
drop(run_sender);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
// Wait for UI to shutdown
|
||||
let _ = shutdown_receiver.recv().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ struct Args {
|
||||
sequence_length: u32,
|
||||
#[clap(default_value = "100", long, env)]
|
||||
decode_length: u32,
|
||||
#[clap(default_value = "10", long, env)]
|
||||
#[clap(default_value = "2", long, env)]
|
||||
runs: usize,
|
||||
#[clap(default_value = "/tmp/text-generation-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
|
@ -13,14 +13,15 @@ use tui::style::{Color, Style};
|
||||
use tui::text::{Span, Spans};
|
||||
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
|
||||
use tui::Terminal;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
use tokio::sync::{mpsc, broadcast};
|
||||
use crate::{Run, Step};
|
||||
|
||||
pub(crate) struct UI {
|
||||
pub(crate) n_run: usize,
|
||||
pub(crate) n_batch: usize,
|
||||
pub(crate) n_batch_done: usize,
|
||||
pub(crate) run_receiver: Receiver<Run>,
|
||||
pub(crate) run_receiver: mpsc::Receiver<Run>,
|
||||
pub(crate) shutdown_sender: broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
impl UI {
|
||||
@ -175,6 +176,23 @@ impl UI {
|
||||
f.render_widget(decode_throughput_statics, decode_text[1]);
|
||||
})?;
|
||||
|
||||
while crossterm::event::poll(Duration::from_secs(0))? {
|
||||
match crossterm::event::read()? {
|
||||
Event::Key(KeyEvent {
|
||||
code: KeyCode::Char('q'),
|
||||
..
|
||||
})
|
||||
| Event::Key(KeyEvent {
|
||||
code: KeyCode::Char('c'),
|
||||
modifiers: KeyModifiers::CONTROL,
|
||||
..
|
||||
}) => {
|
||||
break 'outer;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
let per_frame = Duration::from_secs(1) / 30 as u32;
|
||||
let elapsed = frame_start.elapsed();
|
||||
if per_frame > elapsed {
|
||||
@ -182,10 +200,11 @@ impl UI {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
||||
crossterm::terminal::disable_raw_mode()?;
|
||||
io::stdout().execute(crossterm::cursor::Show)?;
|
||||
|
||||
let _ = self.shutdown_sender.send(());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user