improving design

OlivierDehaene 2023-03-30 10:35:18 +02:00
parent 1c5d526943
commit a1613e2518
3 changed files with 255 additions and 222 deletions

benchmark/src/generation.rs (new file)

@@ -0,0 +1,194 @@
use std::time::{Duration, Instant};
use text_generation_client::{
    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
    StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};

const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
/// Timing of a single prefill step.
#[derive(Debug, Clone)]
pub(crate) struct Prefill {
    pub(crate) latency: Duration,
    pub(crate) throughput: f64,
}

/// Timing of a full decode pass.
#[derive(Debug, Clone)]
pub(crate) struct Decode {
    pub(crate) decode_length: u32,
    pub(crate) latency: Duration,
    pub(crate) throughput: f64,
}

/// One complete benchmark run: a prefill followed by a decode at a given batch size.
#[derive(Debug)]
pub(crate) struct Run {
    pub(crate) batch_size: u32,
    pub(crate) sequence_length: u32,
    pub(crate) prefill: Prefill,
    pub(crate) decode: Decode,
}

/// Messages sent from the generation task to the UI.
#[derive(Debug)]
pub(crate) enum Message {
    Warmup,
    Prefill(Prefill),
    Decode(Decode),
    Run(Run),
    EndBatch,
}
/// Benchmark task: runs all generations, then exits on completion or on a shutdown signal.
pub(crate) async fn generation_task(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    n_runs: usize,
    warmups: usize,
    client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
    mut shutdown_receiver: broadcast::Receiver<()>,
    _shutdown_guard_sender: mpsc::Sender<()>,
) {
    // Stop early if a shutdown signal is received; forward any client error to the receiver.
    tokio::select! {
        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => {
            if let Err(err) = res {
                run_sender.send(Err(err)).await.unwrap_or(());
            }
        },
        _ = shutdown_receiver.recv() => {}
    };
}
/// Runs the warmup and measured iterations for every batch size.
async fn generate_runs(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    n_runs: usize,
    warmups: usize,
    mut client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
    // Create a prompt of roughly `sequence_length` tokens.
    let sequence = create_sequence(sequence_length, tokenizer);

    for b in batch_size {
        // Warmup runs: results are discarded.
        for _ in 0..warmups {
            let (_, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?;
            let _ = decode(decode_batch, &mut client).await?;
            run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
        }

        // Measured runs.
        for _ in 0..n_runs {
            let (prefill, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?;
            run_sender
                .send(Ok(Message::Prefill(prefill.clone())))
                .await
                .unwrap_or(());

            let decode = decode(decode_batch, &mut client).await?;
            run_sender
                .send(Ok(Message::Decode(decode.clone())))
                .await
                .unwrap_or(());

            run_sender
                .send(Ok(Message::Run(Run {
                    batch_size: b,
                    sequence_length,
                    prefill,
                    decode,
                })))
                .await
                .unwrap_or(());
        }
        run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
    }
    Ok(())
}
/// Runs a single prefill step for `batch_size` identical requests and returns its timing
/// along with the cached batch to decode from.
async fn prefill(
    sequence: String,
    batch_size: u32,
    decode_length: u32,
    client: &mut ShardedClient,
) -> Result<(Prefill, Batch), ClientError> {
    // Greedy decoding, no stop sequences, and EOS ignored so every request generates
    // exactly `decode_length` tokens.
    let requests = (0..batch_size)
        .map(|id| Request {
            id: id.into(),
            inputs: sequence.clone(),
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
                top_p: 1.0,
                typical_p: 1.0,
                do_sample: false,
                seed: 0,
                repetition_penalty: 1.0,
                watermark: false,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
                stop_sequences: vec![],
                ignore_eos_token: true,
            }),
        })
        .collect();

    let batch = Batch {
        id: 0,
        requests,
        size: batch_size,
    };

    let start_time = Instant::now();
    let (_, decode_batch) = client.prefill(batch.clone()).await?;
    let latency = start_time.elapsed();

    // Prefill throughput, in requests per second.
    let throughput = batch_size as f64 / latency.as_secs_f64();
    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

    let step = Prefill {
        latency,
        throughput,
    };
    Ok((step, decode_batch))
}
/// Decodes the batch until every request has produced its `max_new_tokens` tokens.
async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
    let mut decode_length = 0;
    let start_time = Instant::now();
    let batch_size = batch.size;

    let mut next_batch = Some(batch);
    while let Some(batch) = next_batch {
        let result = client.decode(vec![batch]).await?;
        next_batch = result.1;
        decode_length += 1;
    }
    let latency = start_time.elapsed();

    // Decode throughput, in generated tokens per second across the whole batch.
    let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();

    let step = Decode {
        decode_length,
        latency,
        throughput,
    };
    Ok(step)
}
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
    // Repeat lorem ipsum to cover sequence length
    let string_sequence =
        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
    // Encode sequence
    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
    // Truncate to sequence_length
    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
    // Decode
    tokenizer
        .decode(Vec::from(encoding.get_ids()), false)
        .unwrap()
}
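Note on the shutdown handling above: `generation_task` takes a `_shutdown_guard_sender` that it never sends on. The task merely holds the `mpsc::Sender` for its whole lifetime; once every clone of that sender has been dropped, the matching receiver's `recv()` returns `None`, which is how `run()` (in the diff below) waits for the background tasks to finish. The following is a minimal, self-contained sketch of that guard pattern, illustrative only and not part of this commit; the names `guard_tx` and `guard_rx` are made up for the example.

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Each spawned task holds a clone of the guard sender and never sends on it.
    let (guard_tx, mut guard_rx) = mpsc::channel::<()>(1);

    for i in 0..3 {
        let guard = guard_tx.clone();
        tokio::spawn(async move {
            // ... task body runs here ...
            println!("task {i} done");
            // The clone is dropped when the task finishes.
            drop(guard);
        });
    }

    // Drop the original sender so only the tasks keep the channel open.
    drop(guard_tx);

    // recv() resolves with None once every sender clone is gone,
    // i.e. once all spawned tasks have exited.
    let _ = guard_rx.recv().await;
    println!("all tasks finished");
}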

benchmark/src/lib.rs

@@ -2,46 +2,13 @@ extern crate core;
 mod ui;
 mod utils;
+mod generation;
 use crate::ui::UI;
-use std::time::{Duration, Instant};
-use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
-    StoppingCriteriaParameters,
-};
-use tokenizers::{Tokenizer, TruncationDirection};
+use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
+use text_generation_client::ShardedClient;
-const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
-#[derive(Debug, Clone)]
-pub(crate) struct Prefill {
-    batch_size: u32,
-    sequence_length: u32,
-    latency: Duration,
-}
-#[derive(Debug, Clone)]
-pub(crate) struct Decode {
-    batch_size: u32,
-    sequence_length: u32,
-    decode_length: u32,
-    latency: Duration,
-}
-#[derive(Debug)]
-pub(crate) struct Run {
-    prefill: Prefill,
-    decode: Decode,
-}
-#[derive(Debug)]
-pub(crate) enum Message {
-    Prefill(Prefill),
-    Decode(Decode),
-    IncreaseRun,
-    IncreaseBatch,
-}
 pub async fn run(
     tokenizer_name: String,
@@ -51,10 +18,15 @@ pub async fn run(
     decode_length: u32,
     n_runs: usize,
     warmups: usize,
-    mut client: ShardedClient,
+    client: ShardedClient,
 ) -> Result<(), Box<dyn std::error::Error>> {
-    let (ui_sender, ui_receiver) = mpsc::channel(8);
-    let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);
+    let (run_sender, run_receiver) = mpsc::channel(8);
+    let (shutdown_sender, shutdown_receiver) = broadcast::channel(1);
+    let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
+    tokio::spawn(
+        generation::generation_task(tokenizer, batch_size.clone(), sequence_length, decode_length, n_runs, warmups, client, run_sender, shutdown_receiver, shutdown_guard_sender.clone()),
+    );
     tokio::spawn(
         UI {
@@ -62,157 +34,18 @@ pub async fn run(
             decode_length,
             sequence_length,
             n_run: n_runs,
-            batch_size: batch_size.clone(),
-            receiver: ui_receiver,
+            batch_size: batch_size,
+            receiver: run_receiver,
             shutdown_sender,
+            _shutdown_guard_sender: shutdown_guard_sender.clone()
         }
         .draw(),
     );
-    let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
-    let sequence = create_sequence(sequence_length, tokenizer);
-    for b in batch_size {
-        for _ in 0..warmups {
-            let (_, decode_batch) = tokio::select! {
-                res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-            let _ = tokio::select! {
-                res = run_decode(decode_batch, sequence_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-        }
-        for _ in 0..n_runs {
-            let (prefill, decode_batch) = tokio::select! {
-                res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-            ui_sender
-                .send(Message::Prefill(prefill.clone()))
-                .await
-                .unwrap();
-            let decode = tokio::select! {
-                res = run_decode(decode_batch, sequence_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-            ui_sender
-                .send(Message::Decode(decode.clone()))
-                .await
-                .unwrap();
-            runs.push(Run { prefill, decode });
-            ui_sender.send(Message::IncreaseRun).await.unwrap();
-        }
-        ui_sender.send(Message::IncreaseBatch).await.unwrap();
-    }
-    // Signal the UI that we are done
-    drop(ui_sender);
-    // Wait for UI shutdown signal
-    let _ = shutdown_receiver.recv().await;
+    drop (shutdown_guard_sender);
+    // Wait for tasks to shutdown
+    let _ = shutdown_guard_receiver.recv().await;
     Ok(())
 }
-async fn run_prefill(
-    sequence: String,
-    sequence_length: u32,
-    batch_size: u32,
-    decode_length: u32,
-    client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
-    let requests = (0..batch_size)
-        .map(|id| Request {
-            id: id.into(),
-            inputs: sequence.clone(),
-            parameters: Some(NextTokenChooserParameters {
-                temperature: 1.0,
-                top_k: 0,
-                top_p: 1.0,
-                typical_p: 1.0,
-                do_sample: false,
-                seed: 0,
-                repetition_penalty: 1.0,
-                watermark: false,
-            }),
-            stopping_parameters: Some(StoppingCriteriaParameters {
-                max_new_tokens: decode_length,
-                stop_sequences: vec![],
-                ignore_eos_token: true,
-            }),
-        })
-        .collect();
-    let batch = Batch {
-        id: 0,
-        requests,
-        size: batch_size,
-    };
-    let start_time = Instant::now();
-    let (_, decode_batch) = client.prefill(batch.clone()).await?;
-    let elasped = start_time.elapsed();
-    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
-    let step = Prefill {
-        batch_size,
-        sequence_length,
-        latency: elasped,
-    };
-    Ok((step, decode_batch))
-}
-async fn run_decode(
-    batch: Batch,
-    sequence_length: u32,
-    client: &mut ShardedClient,
-) -> Result<Decode, ClientError> {
-    let batch_size = batch.size;
-    let mut decode_length = 0;
-    let start_time = Instant::now();
-    let mut next_batch = Some(batch);
-    while let Some(batch) = next_batch {
-        let result = client.decode(vec![batch]).await?;
-        next_batch = result.1;
-        decode_length += 1;
-    }
-    let elapsed = start_time.elapsed();
-    let step = Decode {
-        batch_size,
-        sequence_length,
-        decode_length,
-        latency: elapsed,
-    };
-    Ok(step)
-}
-fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
-    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
-    // Repeat lorem ipsum to cover sequence length
-    let string_sequence =
-        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
-    // Encode sequence
-    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
-    // Truncate to sequence_length
-    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
-    // Decode
-    tokenizer
-        .decode(Vec::from(encoding.get_ids()), false)
-        .unwrap()
-}

benchmark/src/ui.rs

@@ -1,5 +1,4 @@
 /// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs
-use crate::Message;
 use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers};
 use crossterm::{event, ExecutableCommand};
 use std::io;
@@ -15,6 +14,8 @@ use tui::widgets::{
     Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
 };
 use tui::{symbols, Terminal};
+use text_generation_client::ClientError;
+use crate::generation::Message;
 pub(crate) struct UI {
     pub(crate) tokenizer_name: String,
@@ -22,8 +23,9 @@ pub(crate) struct UI {
     pub(crate) decode_length: u32,
     pub(crate) n_run: usize,
     pub(crate) batch_size: Vec<u32>,
-    pub(crate) receiver: mpsc::Receiver<Message>,
+    pub(crate) receiver: mpsc::Receiver<Result<Message, ClientError>>,
     pub(crate) shutdown_sender: broadcast::Sender<()>,
+    pub(crate) _shutdown_guard_sender: mpsc::Sender<()>,
 }
 impl UI {
@@ -57,6 +59,7 @@ impl UI {
         let mut completed_runs: Vec<usize> = (0..self.batch_size.len()).map(|_| 0).collect();
         let mut completed_batch = 0;
         let mut current_batch_idx = 0;
+        let mut is_error = false;
         let mut terminal = {
             let backend = CrosstermBackend::new(io::stdout());
@@ -68,41 +71,44 @@ impl UI {
         loop {
             match self.receiver.try_recv() {
                 Ok(message) => match message {
-                    Message::Prefill(step) => {
-                        let latency = step.latency.as_millis() as f64;
-                        let throughput = step.batch_size as f64 / step.latency.as_secs_f64();
-                        prefill_latencies[current_batch_idx].push(latency);
-                        prefill_throughputs[current_batch_idx].push(throughput);
-                    }
-                    Message::Decode(step) => {
-                        let latency = step.latency.as_millis() as f64;
-                        let throughput = (step.batch_size * step.decode_length) as f64
-                            / step.latency.as_secs_f64();
-                        decode_latencies[current_batch_idx].push(latency);
-                        decode_throughputs[current_batch_idx].push(throughput);
-                    }
-                    Message::IncreaseRun => {
-                        completed_runs[current_batch_idx] += 1;
-                    }
-                    Message::IncreaseBatch => {
-                        prefill_batch_latency_throughput.push((
-                            prefill_latencies[current_batch_idx].iter().sum::<f64>()
-                                / completed_runs[current_batch_idx] as f64,
-                            prefill_throughputs[current_batch_idx].iter().sum::<f64>()
-                                / completed_runs[current_batch_idx] as f64,
-                        ));
-                        decode_batch_latency_throughput.push((
-                            decode_latencies[current_batch_idx].iter().sum::<f64>()
-                                / completed_runs[current_batch_idx] as f64,
-                            decode_throughputs[current_batch_idx].iter().sum::<f64>()
-                                / completed_runs[current_batch_idx] as f64,
-                        ));
-                        completed_batch += 1;
-                        if current_batch_idx < self.batch_size.len() - 1 {
-                            current_batch_idx += 1;
-                        }
-                    }
+                    Ok(message) => {
+                        match message {
+                            Message::Prefill(step) => {
+                                let latency = step.latency.as_millis() as f64;
+                                prefill_latencies[current_batch_idx].push(latency);
+                                prefill_throughputs[current_batch_idx].push(step.throughput);
+                            }
+                            Message::Decode(step) => {
+                                let latency = step.latency.as_millis() as f64;
+                                decode_latencies[current_batch_idx].push(latency);
+                                decode_throughputs[current_batch_idx].push(step.throughput);
+                            }
+                            Message::Run(_) => {
+                                completed_runs[current_batch_idx] += 1;
+                            }
+                            Message::EndBatch => {
+                                prefill_batch_latency_throughput.push((
+                                    prefill_latencies[current_batch_idx].iter().sum::<f64>()
+                                        / completed_runs[current_batch_idx] as f64,
+                                    prefill_throughputs[current_batch_idx].iter().sum::<f64>()
+                                        / completed_runs[current_batch_idx] as f64,
+                                ));
+                                decode_batch_latency_throughput.push((
+                                    decode_latencies[current_batch_idx].iter().sum::<f64>()
+                                        / completed_runs[current_batch_idx] as f64,
+                                    decode_throughputs[current_batch_idx].iter().sum::<f64>()
+                                        / completed_runs[current_batch_idx] as f64,
+                                ));
+                                completed_batch += 1;
+                                if current_batch_idx < self.batch_size.len() - 1 {
+                                    current_batch_idx += 1;
+                                }
+                            }
+                            Message::Warmup => {}
+                        }
+                    }
+                    Err(_) => is_error = true
                 },
                 Err(TryRecvError::Empty) => {
                     break;
@@ -130,7 +136,7 @@ impl UI {
                     Constraint::Length(13),
                     Constraint::Min(10),
                 ]
-                .as_ref(),
+                .as_ref(),
             )
             .split(f.size());
@@ -150,7 +156,7 @@ impl UI {
                     Constraint::Percentage(20),
                     Constraint::Percentage(30),
                 ]
-                .as_ref(),
+                .as_ref(),
             )
             .split(row5[3]);
@@ -235,7 +241,7 @@ impl UI {
             } else {
                 (mid[1].width as usize - 2) / (histo_width + 1)
             }
-            .max(2);
+            .max(2);
             let histo_data = latency_histogram_data(&prefill_latencies[current_tab_idx], bins);
             let histo_data_str: Vec<(&str, u64)> =
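The `latency_histogram_data` helper used above is defined elsewhere in the benchmark crate and is not part of this diff. For orientation, a plausible sketch of such a binning helper follows; it is an assumption about its behavior (equal-width bins over the observed latency range), not the actual implementation, and its signature (owned `String` labels) differs from the `Vec<(&str, u64)>` used above.

/// Illustrative only: bucket latencies (in ms) into `bins` equal-width buckets and
/// return (upper-bound label, count) pairs suitable for a tui BarChart.
fn latency_histogram_data(latencies: &[f64], bins: usize) -> Vec<(String, u64)> {
    if latencies.is_empty() || bins == 0 {
        return Vec::new();
    }
    let min = latencies.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = latencies.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // Guard against a zero-width range when all samples are identical.
    let width = ((max - min) / bins as f64).max(f64::EPSILON);

    let mut counts = vec![0u64; bins];
    for &latency in latencies {
        let idx = (((latency - min) / width) as usize).min(bins - 1);
        counts[idx] += 1;
    }
    counts
        .into_iter()
        .enumerate()
        .map(|(i, count)| (format!("{:.0}", min + width * (i + 1) as f64), count))
        .collect()
}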