From a1613e25180ff688dcc19ba0c7de670c3846a18b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 30 Mar 2023 10:35:18 +0200 Subject: [PATCH] improving design --- benchmark/src/generation.rs | 194 ++++++++++++++++++++++++++++++++++ benchmark/src/lib.rs | 203 ++++-------------------------------- benchmark/src/ui.rs | 80 +++++++------- 3 files changed, 255 insertions(+), 222 deletions(-) create mode 100644 benchmark/src/generation.rs diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs new file mode 100644 index 00000000..bdc7d084 --- /dev/null +++ b/benchmark/src/generation.rs @@ -0,0 +1,194 @@ +use std::time::{Duration, Instant}; +use text_generation_client::{Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters}; +use tokenizers::{Tokenizer, TruncationDirection}; +use tokio::sync::{broadcast, mpsc}; + +const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; + +#[derive(Debug, Clone)] +pub(crate) struct Prefill { + pub(crate) latency: Duration, + pub(crate) throughput: f64, +} + +#[derive(Debug, Clone)] +pub(crate) struct Decode { + pub(crate) decode_length: u32, + pub(crate) latency: Duration, + pub(crate) throughput: f64, +} + +#[derive(Debug)] +pub(crate) struct Run { + pub(crate) batch_size: u32, + pub(crate) sequence_length: u32, + pub(crate) prefill: Prefill, + pub(crate) decode: Decode, +} + +#[derive(Debug)] +pub(crate) enum Message { + Warmup, + Prefill(Prefill), + Decode(Decode), + Run(Run), + EndBatch, +} + +pub(crate) async fn generation_task( + tokenizer: Tokenizer, + batch_size: Vec, + sequence_length: u32, + decode_length: u32, + n_runs: usize, + warmups: usize, + client: ShardedClient, + run_sender: mpsc::Sender>, + mut shutdown_receiver: broadcast::Receiver<()>, + _shutdown_guard_sender: mpsc::Sender<()>, +) { + tokio::select! { + res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => { + if let Err(err) = res { + run_sender.send(Err(err)).await.unwrap_or(()); + } + }, + _ = shutdown_receiver.recv() => {} + } + ; +} + +async fn generate_runs(tokenizer: Tokenizer, + batch_size: Vec, + sequence_length: u32, + decode_length: u32, + n_runs: usize, + warmups: usize, + mut client: ShardedClient, + run_sender: mpsc::Sender>, +) -> Result<(), ClientError> { + let sequence = create_sequence(sequence_length, tokenizer); + + for b in batch_size { + for _ in 0..warmups { + let (_, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?; + let _ = decode(decode_batch, &mut client).await?; + run_sender.send(Ok(Message::Warmup)).await.unwrap_or(()); + } + + for _ in 0..n_runs { + let (prefill, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?; + run_sender + .send(Ok(Message::Prefill(prefill.clone()))) + .await + .unwrap_or(()); + + let decode = decode(decode_batch, &mut client).await?; + + run_sender + .send(Ok(Message::Decode(decode.clone()))) + .await + .unwrap_or(()); + + run_sender.send(Ok(Message::Run(Run { + batch_size: b, + sequence_length, + prefill, + decode, + }))).await.unwrap_or(()); + } + run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(()); + } + Ok(()) +} + +async fn prefill( + sequence: String, + batch_size: u32, + decode_length: u32, + client: &mut ShardedClient, +) -> Result<(Prefill, Batch), ClientError> { + let requests = (0..batch_size) + .map(|id| Request { + id: id.into(), + inputs: sequence.clone(), + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + watermark: false, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: decode_length, + stop_sequences: vec![], + ignore_eos_token: true, + }), + }) + .collect(); + + let batch = Batch { + id: 0, + requests, + size: batch_size, + }; + + let start_time = Instant::now(); + let (_, decode_batch) = client.prefill(batch.clone()).await?; + let latency = start_time.elapsed(); + let throughput = batch_size as f64 + / latency.as_secs_f64(); + + let decode_batch = decode_batch.expect("decode_batch is None. This is a bug."); + + let step = Prefill { + latency, + throughput, + }; + + Ok((step, decode_batch)) +} + +async fn decode( + batch: Batch, + client: &mut ShardedClient, +) -> Result { + let mut decode_length = 0; + let start_time = Instant::now(); + let batch_size = batch.size; + + let mut next_batch = Some(batch); + while let Some(batch) = next_batch { + let result = client.decode(vec![batch]).await?; + next_batch = result.1; + decode_length += 1; + } + let latency = start_time.elapsed(); + let throughput = (batch_size * decode_length) as f64 + / latency.as_secs_f64(); + + let step = Decode { + decode_length, + latency, + throughput, + }; + Ok(step) +} + +fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String { + let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len(); + // Repeat lorem ipsum to cover sequence length + let string_sequence = + LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len()); + // Encode sequence + let mut encoding = tokenizer.encode(string_sequence, true).unwrap(); + // Truncate to sequence_length + encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left); + // Decode + tokenizer + .decode(Vec::from(encoding.get_ids()), false) + .unwrap() +} diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index d30745c1..935c11dc 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -2,46 +2,13 @@ extern crate core; mod ui; mod utils; +mod generation; use crate::ui::UI; -use std::time::{Duration, Instant}; -use text_generation_client::{ - Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient, - StoppingCriteriaParameters, -}; -use tokenizers::{Tokenizer, TruncationDirection}; +use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; +use text_generation_client::ShardedClient; -const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; - -#[derive(Debug, Clone)] -pub(crate) struct Prefill { - batch_size: u32, - sequence_length: u32, - latency: Duration, -} - -#[derive(Debug, Clone)] -pub(crate) struct Decode { - batch_size: u32, - sequence_length: u32, - decode_length: u32, - latency: Duration, -} - -#[derive(Debug)] -pub(crate) struct Run { - prefill: Prefill, - decode: Decode, -} - -#[derive(Debug)] -pub(crate) enum Message { - Prefill(Prefill), - Decode(Decode), - IncreaseRun, - IncreaseBatch, -} pub async fn run( tokenizer_name: String, @@ -51,10 +18,15 @@ pub async fn run( decode_length: u32, n_runs: usize, warmups: usize, - mut client: ShardedClient, + client: ShardedClient, ) -> Result<(), Box> { - let (ui_sender, ui_receiver) = mpsc::channel(8); - let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1); + let (run_sender, run_receiver) = mpsc::channel(8); + let (shutdown_sender, shutdown_receiver) = broadcast::channel(1); + let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1); + + tokio::spawn( + generation::generation_task(tokenizer, batch_size.clone(), sequence_length, decode_length, n_runs, warmups, client, run_sender, shutdown_receiver, shutdown_guard_sender.clone()), + ); tokio::spawn( UI { @@ -62,157 +34,18 @@ pub async fn run( decode_length, sequence_length, n_run: n_runs, - batch_size: batch_size.clone(), - receiver: ui_receiver, + batch_size: batch_size, + receiver: run_receiver, shutdown_sender, + _shutdown_guard_sender: shutdown_guard_sender.clone() } - .draw(), + .draw(), ); - let mut runs = Vec::with_capacity(batch_size.len() * n_runs); - let sequence = create_sequence(sequence_length, tokenizer); + drop (shutdown_guard_sender); - for b in batch_size { - for _ in 0..warmups { - let (_, decode_batch) = tokio::select! { - res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?, - _ = shutdown_receiver.recv() => { - return Ok(()); - } - }; - let _ = tokio::select! { - res = run_decode(decode_batch, sequence_length, &mut client) => res?, - _ = shutdown_receiver.recv() => { - return Ok(()); - } - }; - } - - for _ in 0..n_runs { - let (prefill, decode_batch) = tokio::select! { - res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?, - _ = shutdown_receiver.recv() => { - return Ok(()); - } - }; - ui_sender - .send(Message::Prefill(prefill.clone())) - .await - .unwrap(); - - let decode = tokio::select! { - res = run_decode(decode_batch, sequence_length, &mut client) => res?, - _ = shutdown_receiver.recv() => { - return Ok(()); - } - }; - - ui_sender - .send(Message::Decode(decode.clone())) - .await - .unwrap(); - runs.push(Run { prefill, decode }); - - ui_sender.send(Message::IncreaseRun).await.unwrap(); - } - ui_sender.send(Message::IncreaseBatch).await.unwrap(); - } - - // Signal the UI that we are done - drop(ui_sender); - - // Wait for UI shutdown signal - let _ = shutdown_receiver.recv().await; + // Wait for tasks to shutdown + let _ = shutdown_guard_receiver.recv().await; Ok(()) } - -async fn run_prefill( - sequence: String, - sequence_length: u32, - batch_size: u32, - decode_length: u32, - client: &mut ShardedClient, -) -> Result<(Prefill, Batch), ClientError> { - let requests = (0..batch_size) - .map(|id| Request { - id: id.into(), - inputs: sequence.clone(), - parameters: Some(NextTokenChooserParameters { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - typical_p: 1.0, - do_sample: false, - seed: 0, - repetition_penalty: 1.0, - watermark: false, - }), - stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: decode_length, - stop_sequences: vec![], - ignore_eos_token: true, - }), - }) - .collect(); - - let batch = Batch { - id: 0, - requests, - size: batch_size, - }; - - let start_time = Instant::now(); - let (_, decode_batch) = client.prefill(batch.clone()).await?; - let elasped = start_time.elapsed(); - - let decode_batch = decode_batch.expect("decode_batch is None. This is a bug."); - - let step = Prefill { - batch_size, - sequence_length, - latency: elasped, - }; - - Ok((step, decode_batch)) -} - -async fn run_decode( - batch: Batch, - sequence_length: u32, - client: &mut ShardedClient, -) -> Result { - let batch_size = batch.size; - let mut decode_length = 0; - let start_time = Instant::now(); - - let mut next_batch = Some(batch); - while let Some(batch) = next_batch { - let result = client.decode(vec![batch]).await?; - next_batch = result.1; - decode_length += 1; - } - let elapsed = start_time.elapsed(); - let step = Decode { - batch_size, - sequence_length, - decode_length, - latency: elapsed, - }; - Ok(step) -} - -fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String { - let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len(); - // Repeat lorem ipsum to cover sequence length - let string_sequence = - LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len()); - // Encode sequence - let mut encoding = tokenizer.encode(string_sequence, true).unwrap(); - // Truncate to sequence_length - encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left); - // Decode - tokenizer - .decode(Vec::from(encoding.get_ids()), false) - .unwrap() -} diff --git a/benchmark/src/ui.rs b/benchmark/src/ui.rs index adea0e23..ed2875ab 100644 --- a/benchmark/src/ui.rs +++ b/benchmark/src/ui.rs @@ -1,5 +1,4 @@ /// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs -use crate::Message; use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers}; use crossterm::{event, ExecutableCommand}; use std::io; @@ -15,6 +14,8 @@ use tui::widgets::{ Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs, }; use tui::{symbols, Terminal}; +use text_generation_client::ClientError; +use crate::generation::Message; pub(crate) struct UI { pub(crate) tokenizer_name: String, @@ -22,8 +23,9 @@ pub(crate) struct UI { pub(crate) decode_length: u32, pub(crate) n_run: usize, pub(crate) batch_size: Vec, - pub(crate) receiver: mpsc::Receiver, + pub(crate) receiver: mpsc::Receiver>, pub(crate) shutdown_sender: broadcast::Sender<()>, + pub(crate) _shutdown_guard_sender: mpsc::Sender<()>, } impl UI { @@ -57,6 +59,7 @@ impl UI { let mut completed_runs: Vec = (0..self.batch_size.len()).map(|_| 0).collect(); let mut completed_batch = 0; let mut current_batch_idx = 0; + let mut is_error = false; let mut terminal = { let backend = CrosstermBackend::new(io::stdout()); @@ -68,41 +71,44 @@ impl UI { loop { match self.receiver.try_recv() { Ok(message) => match message { - Message::Prefill(step) => { - let latency = step.latency.as_millis() as f64; - let throughput = step.batch_size as f64 / step.latency.as_secs_f64(); - prefill_latencies[current_batch_idx].push(latency); - prefill_throughputs[current_batch_idx].push(throughput); - } - Message::Decode(step) => { - let latency = step.latency.as_millis() as f64; - let throughput = (step.batch_size * step.decode_length) as f64 - / step.latency.as_secs_f64(); - decode_latencies[current_batch_idx].push(latency); - decode_throughputs[current_batch_idx].push(throughput); - } - Message::IncreaseRun => { - completed_runs[current_batch_idx] += 1; - } - Message::IncreaseBatch => { - prefill_batch_latency_throughput.push(( - prefill_latencies[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - prefill_throughputs[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - )); - decode_batch_latency_throughput.push(( - decode_latencies[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - decode_throughputs[current_batch_idx].iter().sum::() - / completed_runs[current_batch_idx] as f64, - )); + Ok(message) => { + match message { + Message::Prefill(step) => { + let latency = step.latency.as_millis() as f64; + prefill_latencies[current_batch_idx].push(latency); + prefill_throughputs[current_batch_idx].push(step.throughput); + } + Message::Decode(step) => { + let latency = step.latency.as_millis() as f64; + decode_latencies[current_batch_idx].push(latency); + decode_throughputs[current_batch_idx].push(step.throughput); + } + Message::Run(_) => { + completed_runs[current_batch_idx] += 1; + } + Message::EndBatch => { + prefill_batch_latency_throughput.push(( + prefill_latencies[current_batch_idx].iter().sum::() + / completed_runs[current_batch_idx] as f64, + prefill_throughputs[current_batch_idx].iter().sum::() + / completed_runs[current_batch_idx] as f64, + )); + decode_batch_latency_throughput.push(( + decode_latencies[current_batch_idx].iter().sum::() + / completed_runs[current_batch_idx] as f64, + decode_throughputs[current_batch_idx].iter().sum::() + / completed_runs[current_batch_idx] as f64, + )); - completed_batch += 1; - if current_batch_idx < self.batch_size.len() - 1 { - current_batch_idx += 1; + completed_batch += 1; + if current_batch_idx < self.batch_size.len() - 1 { + current_batch_idx += 1; + } + } + Message::Warmup => {} } } + Err(_) => is_error = true }, Err(TryRecvError::Empty) => { break; @@ -130,7 +136,7 @@ impl UI { Constraint::Length(13), Constraint::Min(10), ] - .as_ref(), + .as_ref(), ) .split(f.size()); @@ -150,7 +156,7 @@ impl UI { Constraint::Percentage(20), Constraint::Percentage(30), ] - .as_ref(), + .as_ref(), ) .split(row5[3]); @@ -235,7 +241,7 @@ impl UI { } else { (mid[1].width as usize - 2) / (histo_width + 1) } - .max(2); + .max(2); let histo_data = latency_histogram_data(&prefill_latencies[current_tab_idx], bins); let histo_data_str: Vec<(&str, u64)> =