mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
add shutdown logic
This commit is contained in:
parent
c0d793d2ca
commit
681744b982
@ -4,18 +4,18 @@ use std::time::{Duration, Instant};
|
|||||||
use tokenizers::{Tokenizer, TruncationDirection};
|
use tokenizers::{Tokenizer, TruncationDirection};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
|
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::{mpsc, broadcast};
|
||||||
use crate::ui::UI;
|
use crate::ui::UI;
|
||||||
|
|
||||||
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
|
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) enum Step {
|
pub(crate) enum Step {
|
||||||
Prefill,
|
Prefill,
|
||||||
Decode,
|
Decode,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct Run {
|
pub(crate) struct Run {
|
||||||
step: Step,
|
step: Step,
|
||||||
batch_size: u32,
|
batch_size: u32,
|
||||||
@ -29,39 +29,56 @@ pub async fn run(
|
|||||||
batch_size: Vec<u32>,
|
batch_size: Vec<u32>,
|
||||||
sequence_length: u32,
|
sequence_length: u32,
|
||||||
decode_length: u32,
|
decode_length: u32,
|
||||||
runs: usize,
|
n_runs: usize,
|
||||||
mut client: ShardedClient,
|
mut client: ShardedClient,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (sender, receiver) = mpsc::channel(8);
|
let (run_sender, run_receiver) = mpsc::channel(8);
|
||||||
|
let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);
|
||||||
|
|
||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
UI {
|
UI {
|
||||||
n_run: runs,
|
n_run: n_runs,
|
||||||
n_batch: batch_size.len(),
|
n_batch: batch_size.len(),
|
||||||
n_batch_done: 0,
|
n_batch_done: 0,
|
||||||
run_receiver: receiver,
|
run_receiver,
|
||||||
|
shutdown_sender,
|
||||||
}.draw()
|
}.draw()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
|
||||||
let sequence = create_sequence(sequence_length, tokenizer);
|
let sequence = create_sequence(sequence_length, tokenizer);
|
||||||
|
|
||||||
|
|
||||||
for b in batch_size {
|
for b in batch_size {
|
||||||
for _ in 0..runs {
|
for _ in 0..n_runs {
|
||||||
let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
|
let (run, decode_batch) = tokio::select! {
|
||||||
sender.send(run).await.unwrap();
|
res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
|
||||||
|
_ = shutdown_receiver.recv() => {
|
||||||
|
tracing::info!("shutdown");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
run_sender.send(run.clone()).await.unwrap();
|
||||||
|
runs.push(run);
|
||||||
|
|
||||||
let run = run_decode(decode_batch, sequence_length, &mut client).await?;
|
let run = tokio::select! {
|
||||||
sender.send(run).await.unwrap();
|
res = run_decode(decode_batch, sequence_length, &mut client) => res?,
|
||||||
|
_ = shutdown_receiver.recv() => {
|
||||||
|
tracing::info!("shutdown");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
run_sender.send(run.clone()).await.unwrap();
|
||||||
|
runs.push(run);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drop(sender);
|
// Shutdown UI by dropping run sender triggering the UI to exit its rendering loop
|
||||||
|
drop(run_sender);
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
// Wait for UI to shutdown
|
||||||
|
let _ = shutdown_receiver.recv().await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ struct Args {
|
|||||||
sequence_length: u32,
|
sequence_length: u32,
|
||||||
#[clap(default_value = "100", long, env)]
|
#[clap(default_value = "100", long, env)]
|
||||||
decode_length: u32,
|
decode_length: u32,
|
||||||
#[clap(default_value = "10", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
runs: usize,
|
runs: usize,
|
||||||
#[clap(default_value = "/tmp/text-generation-0", long, env)]
|
#[clap(default_value = "/tmp/text-generation-0", long, env)]
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
|
@ -13,14 +13,15 @@ use tui::style::{Color, Style};
|
|||||||
use tui::text::{Span, Spans};
|
use tui::text::{Span, Spans};
|
||||||
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
|
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
|
||||||
use tui::Terminal;
|
use tui::Terminal;
|
||||||
use tokio::sync::mpsc::Receiver;
|
use tokio::sync::{mpsc, broadcast};
|
||||||
use crate::{Run, Step};
|
use crate::{Run, Step};
|
||||||
|
|
||||||
pub(crate) struct UI {
|
pub(crate) struct UI {
|
||||||
pub(crate) n_run: usize,
|
pub(crate) n_run: usize,
|
||||||
pub(crate) n_batch: usize,
|
pub(crate) n_batch: usize,
|
||||||
pub(crate) n_batch_done: usize,
|
pub(crate) n_batch_done: usize,
|
||||||
pub(crate) run_receiver: Receiver<Run>,
|
pub(crate) run_receiver: mpsc::Receiver<Run>,
|
||||||
|
pub(crate) shutdown_sender: broadcast::Sender<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UI {
|
impl UI {
|
||||||
@ -175,6 +176,23 @@ impl UI {
|
|||||||
f.render_widget(decode_throughput_statics, decode_text[1]);
|
f.render_widget(decode_throughput_statics, decode_text[1]);
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
while crossterm::event::poll(Duration::from_secs(0))? {
|
||||||
|
match crossterm::event::read()? {
|
||||||
|
Event::Key(KeyEvent {
|
||||||
|
code: KeyCode::Char('q'),
|
||||||
|
..
|
||||||
|
})
|
||||||
|
| Event::Key(KeyEvent {
|
||||||
|
code: KeyCode::Char('c'),
|
||||||
|
modifiers: KeyModifiers::CONTROL,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let per_frame = Duration::from_secs(1) / 30 as u32;
|
let per_frame = Duration::from_secs(1) / 30 as u32;
|
||||||
let elapsed = frame_start.elapsed();
|
let elapsed = frame_start.elapsed();
|
||||||
if per_frame > elapsed {
|
if per_frame > elapsed {
|
||||||
@ -182,10 +200,11 @@ impl UI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
||||||
crossterm::terminal::disable_raw_mode()?;
|
crossterm::terminal::disable_raw_mode()?;
|
||||||
io::stdout().execute(crossterm::cursor::Show)?;
|
io::stdout().execute(crossterm::cursor::Show)?;
|
||||||
|
|
||||||
|
let _ = self.shutdown_sender.send(());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user