mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
improvements
This commit is contained in:
parent
383619bd7f
commit
1c5d526943
@ -44,6 +44,7 @@ pub(crate) enum Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
|
tokenizer_name: String,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
batch_size: Vec<u32>,
|
batch_size: Vec<u32>,
|
||||||
sequence_length: u32,
|
sequence_length: u32,
|
||||||
@ -57,6 +58,9 @@ pub async fn run(
|
|||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
UI {
|
UI {
|
||||||
|
tokenizer_name,
|
||||||
|
decode_length,
|
||||||
|
sequence_length,
|
||||||
n_run: n_runs,
|
n_run: n_runs,
|
||||||
batch_size: batch_size.clone(),
|
batch_size: batch_size.clone(),
|
||||||
receiver: ui_receiver,
|
receiver: ui_receiver,
|
||||||
@ -68,22 +72,22 @@ pub async fn run(
|
|||||||
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
|
let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
|
||||||
let sequence = create_sequence(sequence_length, tokenizer);
|
let sequence = create_sequence(sequence_length, tokenizer);
|
||||||
|
|
||||||
for _ in 0..warmups {
|
|
||||||
let (_, decode_batch) = tokio::select! {
|
|
||||||
res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
|
|
||||||
_ = shutdown_receiver.recv() => {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let _ = tokio::select! {
|
|
||||||
res = run_decode(decode_batch, sequence_length, &mut client) => res?,
|
|
||||||
_ = shutdown_receiver.recv() => {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
for b in batch_size {
|
for b in batch_size {
|
||||||
|
for _ in 0..warmups {
|
||||||
|
let (_, decode_batch) = tokio::select! {
|
||||||
|
res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
|
||||||
|
_ = shutdown_receiver.recv() => {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let _ = tokio::select! {
|
||||||
|
res = run_decode(decode_batch, sequence_length, &mut client) => res?,
|
||||||
|
_ = shutdown_receiver.recv() => {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
for _ in 0..n_runs {
|
for _ in 0..n_runs {
|
||||||
let (prefill, decode_batch) = tokio::select! {
|
let (prefill, decode_batch) = tokio::select! {
|
||||||
res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
|
res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
|
||||||
|
@ -15,13 +15,13 @@ struct Args {
|
|||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
#[clap(default_value = "1", long, env)]
|
#[clap(default_value = "1", long, env)]
|
||||||
batch_size: Vec<u32>,
|
batch_size: Vec<u32>,
|
||||||
#[clap(default_value = "12", long, env)]
|
|
||||||
sequence_length: u32,
|
|
||||||
#[clap(default_value = "10", long, env)]
|
#[clap(default_value = "10", long, env)]
|
||||||
|
sequence_length: u32,
|
||||||
|
#[clap(default_value = "64", long, env)]
|
||||||
decode_length: u32,
|
decode_length: u32,
|
||||||
#[clap(default_value = "10", long, env)]
|
#[clap(default_value = "10", long, env)]
|
||||||
runs: usize,
|
runs: usize,
|
||||||
#[clap(default_value = "0", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
warmups: usize,
|
warmups: usize,
|
||||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
@ -74,12 +74,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
.expect("Could not connect to server");
|
.expect("Could not connect to server");
|
||||||
// Clear the cache; useful if the webserver rebooted
|
// Clear the cache; useful if the webserver rebooted
|
||||||
sharded_client
|
sharded_client
|
||||||
.clear_cache()
|
.clear_cache(None)
|
||||||
.await
|
.await
|
||||||
.expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
text_generation_benchmark::run(
|
text_generation_benchmark::run(
|
||||||
|
tokenizer_name,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
batch_size,
|
batch_size,
|
||||||
sequence_length,
|
sequence_length,
|
||||||
|
@ -17,6 +17,9 @@ use tui::widgets::{
|
|||||||
use tui::{symbols, Terminal};
|
use tui::{symbols, Terminal};
|
||||||
|
|
||||||
pub(crate) struct UI {
|
pub(crate) struct UI {
|
||||||
|
pub(crate) tokenizer_name: String,
|
||||||
|
pub(crate) sequence_length: u32,
|
||||||
|
pub(crate) decode_length: u32,
|
||||||
pub(crate) n_run: usize,
|
pub(crate) n_run: usize,
|
||||||
pub(crate) batch_size: Vec<u32>,
|
pub(crate) batch_size: Vec<u32>,
|
||||||
pub(crate) receiver: mpsc::Receiver<Message>,
|
pub(crate) receiver: mpsc::Receiver<Message>,
|
||||||
@ -117,10 +120,11 @@ impl UI {
|
|||||||
|
|
||||||
terminal.draw(|f| {
|
terminal.draw(|f| {
|
||||||
// Vertical layout
|
// Vertical layout
|
||||||
let row4 = Layout::default()
|
let row5 = Layout::default()
|
||||||
.direction(Direction::Vertical)
|
.direction(Direction::Vertical)
|
||||||
.constraints(
|
.constraints(
|
||||||
[
|
[
|
||||||
|
Constraint::Length(1),
|
||||||
Constraint::Length(3),
|
Constraint::Length(3),
|
||||||
Constraint::Length(3),
|
Constraint::Length(3),
|
||||||
Constraint::Length(13),
|
Constraint::Length(13),
|
||||||
@ -134,7 +138,7 @@ impl UI {
|
|||||||
let top = Layout::default()
|
let top = Layout::default()
|
||||||
.direction(Direction::Horizontal)
|
.direction(Direction::Horizontal)
|
||||||
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
.split(row4[0]);
|
.split(row5[2]);
|
||||||
|
|
||||||
// Mid row horizontal layout
|
// Mid row horizontal layout
|
||||||
let mid = Layout::default()
|
let mid = Layout::default()
|
||||||
@ -148,7 +152,7 @@ impl UI {
|
|||||||
]
|
]
|
||||||
.as_ref(),
|
.as_ref(),
|
||||||
)
|
)
|
||||||
.split(row4[2]);
|
.split(row5[3]);
|
||||||
|
|
||||||
// Left mid row vertical layout
|
// Left mid row vertical layout
|
||||||
let prefill_text = Layout::default()
|
let prefill_text = Layout::default()
|
||||||
@ -166,7 +170,36 @@ impl UI {
|
|||||||
let bottom = Layout::default()
|
let bottom = Layout::default()
|
||||||
.direction(Direction::Horizontal)
|
.direction(Direction::Horizontal)
|
||||||
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
.split(row4[3]);
|
.split(row5[4]);
|
||||||
|
|
||||||
|
// Title
|
||||||
|
let title = Block::default().borders(Borders::NONE).title(format!(
|
||||||
|
"Model: {} | Sequence Length: {} | Decode Length: {}",
|
||||||
|
self.tokenizer_name, self.sequence_length, self.decode_length
|
||||||
|
)).style(Style::default().add_modifier(Modifier::BOLD).fg(Color::White));
|
||||||
|
f.render_widget(title, row5[0]);
|
||||||
|
|
||||||
|
// Batch tabs
|
||||||
|
let titles = self
|
||||||
|
.batch_size
|
||||||
|
.iter()
|
||||||
|
.map(|b| {
|
||||||
|
Spans::from(vec![Span::styled(
|
||||||
|
format!("Batch: {b}"),
|
||||||
|
Style::default().fg(Color::White),
|
||||||
|
)])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let tabs = Tabs::new(titles)
|
||||||
|
.block(Block::default().borders(Borders::ALL).title("Tabs"))
|
||||||
|
.select(current_tab_idx)
|
||||||
|
.style(Style::default().fg(Color::LightCyan))
|
||||||
|
.highlight_style(
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.bg(Color::Black),
|
||||||
|
);
|
||||||
|
f.render_widget(tabs, row5[1]);
|
||||||
|
|
||||||
// Total progress bar
|
// Total progress bar
|
||||||
let batch_gauge = progress_gauge(
|
let batch_gauge = progress_gauge(
|
||||||
@ -186,28 +219,6 @@ impl UI {
|
|||||||
);
|
);
|
||||||
f.render_widget(run_gauge, top[1]);
|
f.render_widget(run_gauge, top[1]);
|
||||||
|
|
||||||
// Batch tabs
|
|
||||||
let titles = self
|
|
||||||
.batch_size
|
|
||||||
.iter()
|
|
||||||
.map(|b| {
|
|
||||||
Spans::from(vec![
|
|
||||||
Span::raw(format!("Batch: {b}")), // Span::styled(first, Style::default().fg(Color::Yellow)),
|
|
||||||
// Span::styled(rest, Style::default().fg(Color::Green)),
|
|
||||||
])
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let tabs = Tabs::new(titles)
|
|
||||||
.block(Block::default().borders(Borders::ALL).title("Tabs"))
|
|
||||||
.select(current_tab_idx)
|
|
||||||
.style(Style::default().fg(Color::LightCyan))
|
|
||||||
.highlight_style(
|
|
||||||
Style::default()
|
|
||||||
.add_modifier(Modifier::BOLD)
|
|
||||||
.bg(Color::Black),
|
|
||||||
);
|
|
||||||
f.render_widget(tabs, row4[1]);
|
|
||||||
|
|
||||||
// Prefill text infos
|
// Prefill text infos
|
||||||
let (prefill_latency_statics, prefill_throughput_statics) = text_info(
|
let (prefill_latency_statics, prefill_throughput_statics) = text_info(
|
||||||
&mut prefill_latencies[current_tab_idx],
|
&mut prefill_latencies[current_tab_idx],
|
||||||
@ -384,7 +395,7 @@ fn latency_histogram<'a>(
|
|||||||
.block(
|
.block(
|
||||||
Block::default()
|
Block::default()
|
||||||
.title(format!("{name} latency histogram"))
|
.title(format!("{name} latency histogram"))
|
||||||
.style(Style::default().fg(Color::Yellow).bg(Color::Reset))
|
.style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
|
||||||
.borders(Borders::ALL),
|
.borders(Borders::ALL),
|
||||||
)
|
)
|
||||||
.data(histo_data_str.as_slice())
|
.data(histo_data_str.as_slice())
|
||||||
|
Loading…
Reference in New Issue
Block a user