Mirror of https://github.com/huggingface/text-generation-inference.git
commit a1613e2518
parent 1c5d526943

improving design
benchmark/src/generation.rs (new file, 194 lines)
@@ -0,0 +1,194 @@
use std::time::{Duration, Instant};
use text_generation_client::{
    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
    StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};

const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

#[derive(Debug, Clone)]
pub(crate) struct Prefill {
    pub(crate) latency: Duration,
    pub(crate) throughput: f64,
}

#[derive(Debug, Clone)]
pub(crate) struct Decode {
    pub(crate) decode_length: u32,
    pub(crate) latency: Duration,
    pub(crate) throughput: f64,
}

#[derive(Debug)]
pub(crate) struct Run {
    pub(crate) batch_size: u32,
    pub(crate) sequence_length: u32,
    pub(crate) prefill: Prefill,
    pub(crate) decode: Decode,
}

#[derive(Debug)]
pub(crate) enum Message {
    Warmup,
    Prefill(Prefill),
    Decode(Decode),
    Run(Run),
    EndBatch,
}

pub(crate) async fn generation_task(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    n_runs: usize,
    warmups: usize,
    client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
    mut shutdown_receiver: broadcast::Receiver<()>,
    _shutdown_guard_sender: mpsc::Sender<()>,
) {
    // Run the benchmark until it finishes or a shutdown signal is received,
    // forwarding any client error to the receiving task
    tokio::select! {
        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => {
            if let Err(err) = res {
                run_sender.send(Err(err)).await.unwrap_or(());
            }
        },
        _ = shutdown_receiver.recv() => {}
    }
}

async fn generate_runs(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    n_runs: usize,
    warmups: usize,
    mut client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
    // Create a dummy input sequence with exactly `sequence_length` tokens
    let sequence = create_sequence(sequence_length, tokenizer);

    for b in batch_size {
        // Warmup runs: exercise the server but only report progress
        for _ in 0..warmups {
            let (_, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?;
            let _ = decode(decode_batch, &mut client).await?;
            run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
        }

        // Measured runs
        for _ in 0..n_runs {
            let (prefill, decode_batch) = prefill(sequence.clone(), b, decode_length, &mut client).await?;
            run_sender
                .send(Ok(Message::Prefill(prefill.clone())))
                .await
                .unwrap_or(());

            let decode = decode(decode_batch, &mut client).await?;

            run_sender
                .send(Ok(Message::Decode(decode.clone())))
                .await
                .unwrap_or(());

            run_sender
                .send(Ok(Message::Run(Run {
                    batch_size: b,
                    sequence_length,
                    prefill,
                    decode,
                })))
                .await
                .unwrap_or(());
        }
        run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
    }
    Ok(())
}

async fn prefill(
    sequence: String,
    batch_size: u32,
    decode_length: u32,
    client: &mut ShardedClient,
) -> Result<(Prefill, Batch), ClientError> {
    // Duplicate the input sequence `batch_size` times, with greedy decoding parameters
    let requests = (0..batch_size)
        .map(|id| Request {
            id: id.into(),
            inputs: sequence.clone(),
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
                top_p: 1.0,
                typical_p: 1.0,
                do_sample: false,
                seed: 0,
                repetition_penalty: 1.0,
                watermark: false,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
                stop_sequences: vec![],
                ignore_eos_token: true,
            }),
        })
        .collect();

    let batch = Batch {
        id: 0,
        requests,
        size: batch_size,
    };

    // Time the prefill call; one token is generated per request
    let start_time = Instant::now();
    let (_, decode_batch) = client.prefill(batch.clone()).await?;
    let latency = start_time.elapsed();
    let throughput = batch_size as f64 / latency.as_secs_f64();

    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

    let step = Prefill {
        latency,
        throughput,
    };

    Ok((step, decode_batch))
}

async fn decode(
    batch: Batch,
    client: &mut ShardedClient,
) -> Result<Decode, ClientError> {
    let mut decode_length = 0;
    let start_time = Instant::now();
    let batch_size = batch.size;

    // Decode until the server stops returning a next batch,
    // counting one generated token per request per step
    let mut next_batch = Some(batch);
    while let Some(batch) = next_batch {
        let result = client.decode(vec![batch]).await?;
        next_batch = result.1;
        decode_length += 1;
    }
    let latency = start_time.elapsed();
    let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();

    let step = Decode {
        decode_length,
        latency,
        throughput,
    };
    Ok(step)
}

fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
    // Repeat lorem ipsum to cover sequence length
    let string_sequence =
        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
    // Encode sequence
    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
    // Truncate to sequence_length
    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
    // Decode
    tokenizer
        .decode(Vec::from(encoding.get_ids()), false)
        .unwrap()
}
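
For context, the new generation_task streams its measurements to whoever holds the receiving end of run_sender. A minimal consumer could look like the sketch below; it is illustrative only and not part of this commit, but it uses only the Message and ClientError types referenced above (the consume_runs name is hypothetical).

// Hypothetical consumer sketch: average prefill latency per batch size.
async fn consume_runs(mut receiver: mpsc::Receiver<Result<Message, ClientError>>) {
    let mut prefill_latencies_ms: Vec<f64> = Vec::new();
    while let Some(message) = receiver.recv().await {
        match message {
            Ok(Message::Prefill(step)) => {
                prefill_latencies_ms.push(step.latency.as_millis() as f64);
            }
            Ok(Message::EndBatch) => {
                // One batch size is done: report and reset.
                let avg = prefill_latencies_ms.iter().sum::<f64>()
                    / prefill_latencies_ms.len().max(1) as f64;
                println!("average prefill latency: {avg:.2} ms");
                prefill_latencies_ms.clear();
            }
            Ok(_) => {}
            Err(err) => {
                // A client error aborts the benchmark.
                eprintln!("client error: {err:?}");
                break;
            }
        }
    }
}
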
@@ -2,46 +2,13 @@ extern crate core;
 
 mod ui;
 mod utils;
+mod generation;
 
 use crate::ui::UI;
-use std::time::{Duration, Instant};
-use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
-    StoppingCriteriaParameters,
-};
-use tokenizers::{Tokenizer, TruncationDirection};
+use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
+use text_generation_client::ShardedClient;
 
-const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
-
-#[derive(Debug, Clone)]
-pub(crate) struct Prefill {
-    batch_size: u32,
-    sequence_length: u32,
-    latency: Duration,
-}
-
-#[derive(Debug, Clone)]
-pub(crate) struct Decode {
-    batch_size: u32,
-    sequence_length: u32,
-    decode_length: u32,
-    latency: Duration,
-}
-
-#[derive(Debug)]
-pub(crate) struct Run {
-    prefill: Prefill,
-    decode: Decode,
-}
-
-#[derive(Debug)]
-pub(crate) enum Message {
-    Prefill(Prefill),
-    Decode(Decode),
-    IncreaseRun,
-    IncreaseBatch,
-}
-
 pub async fn run(
     tokenizer_name: String,
@@ -51,10 +18,15 @@ pub async fn run(
     decode_length: u32,
     n_runs: usize,
     warmups: usize,
-    mut client: ShardedClient,
+    client: ShardedClient,
 ) -> Result<(), Box<dyn std::error::Error>> {
-    let (ui_sender, ui_receiver) = mpsc::channel(8);
-    let (shutdown_sender, mut shutdown_receiver) = broadcast::channel(1);
+    let (run_sender, run_receiver) = mpsc::channel(8);
+    let (shutdown_sender, shutdown_receiver) = broadcast::channel(1);
     let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
 
+    tokio::spawn(
+        generation::generation_task(tokenizer, batch_size.clone(), sequence_length, decode_length, n_runs, warmups, client, run_sender, shutdown_receiver, shutdown_guard_sender.clone()),
+    );
+
     tokio::spawn(
         UI {
@@ -62,157 +34,18 @@ pub async fn run(
             decode_length,
             sequence_length,
             n_run: n_runs,
-            batch_size: batch_size.clone(),
-            receiver: ui_receiver,
+            batch_size: batch_size,
+            receiver: run_receiver,
             shutdown_sender,
             _shutdown_guard_sender: shutdown_guard_sender.clone()
         }
         .draw(),
     );
 
-    let mut runs = Vec::with_capacity(batch_size.len() * n_runs);
-    let sequence = create_sequence(sequence_length, tokenizer);
+    drop(shutdown_guard_sender);
 
-    for b in batch_size {
-        for _ in 0..warmups {
-            let (_, decode_batch) = tokio::select! {
-                res = run_prefill(sequence.clone(), sequence_length, 1, decode_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-            let _ = tokio::select! {
-                res = run_decode(decode_batch, sequence_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-        }
-
-        for _ in 0..n_runs {
-            let (prefill, decode_batch) = tokio::select! {
-                res = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-            ui_sender
-                .send(Message::Prefill(prefill.clone()))
-                .await
-                .unwrap();
-
-            let decode = tokio::select! {
-                res = run_decode(decode_batch, sequence_length, &mut client) => res?,
-                _ = shutdown_receiver.recv() => {
-                    return Ok(());
-                }
-            };
-
-            ui_sender
-                .send(Message::Decode(decode.clone()))
-                .await
-                .unwrap();
-            runs.push(Run { prefill, decode });
-
-            ui_sender.send(Message::IncreaseRun).await.unwrap();
-        }
-        ui_sender.send(Message::IncreaseBatch).await.unwrap();
-    }
-
-    // Signal the UI that we are done
-    drop(ui_sender);
-
-    // Wait for UI shutdown signal
-    let _ = shutdown_receiver.recv().await;
     // Wait for tasks to shutdown
     let _ = shutdown_guard_receiver.recv().await;
 
     Ok(())
 }
-
-async fn run_prefill(
-    sequence: String,
-    sequence_length: u32,
-    batch_size: u32,
-    decode_length: u32,
-    client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
-    let requests = (0..batch_size)
-        .map(|id| Request {
-            id: id.into(),
-            inputs: sequence.clone(),
-            parameters: Some(NextTokenChooserParameters {
-                temperature: 1.0,
-                top_k: 0,
-                top_p: 1.0,
-                typical_p: 1.0,
-                do_sample: false,
-                seed: 0,
-                repetition_penalty: 1.0,
-                watermark: false,
-            }),
-            stopping_parameters: Some(StoppingCriteriaParameters {
-                max_new_tokens: decode_length,
-                stop_sequences: vec![],
-                ignore_eos_token: true,
-            }),
-        })
-        .collect();
-
-    let batch = Batch {
-        id: 0,
-        requests,
-        size: batch_size,
-    };
-
-    let start_time = Instant::now();
-    let (_, decode_batch) = client.prefill(batch.clone()).await?;
-    let elasped = start_time.elapsed();
-
-    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
-
-    let step = Prefill {
-        batch_size,
-        sequence_length,
-        latency: elasped,
-    };
-
-    Ok((step, decode_batch))
-}
-
-async fn run_decode(
-    batch: Batch,
-    sequence_length: u32,
-    client: &mut ShardedClient,
-) -> Result<Decode, ClientError> {
-    let batch_size = batch.size;
-    let mut decode_length = 0;
-    let start_time = Instant::now();
-
-    let mut next_batch = Some(batch);
-    while let Some(batch) = next_batch {
-        let result = client.decode(vec![batch]).await?;
-        next_batch = result.1;
-        decode_length += 1;
-    }
-    let elapsed = start_time.elapsed();
-    let step = Decode {
-        batch_size,
-        sequence_length,
-        decode_length,
-        latency: elapsed,
-    };
-    Ok(step)
-}
-
-fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
-    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
-    // Repeat lorem ipsum to cover sequence length
-    let string_sequence =
-        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
-    // Encode sequence
-    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
-    // Truncate to sequence_length
-    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
-    // Decode
-    tokenizer
-        .decode(Vec::from(encoding.get_ids()), false)
-        .unwrap()
-}
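
The reworked run() leans on a shutdown-guard idiom: each spawned task holds a clone of an mpsc::Sender<()> that is never sent on, and run() waits on the matching receiver, which only yields None once every task has exited and dropped its sender. A small standalone sketch of that idiom, illustrative only and not taken from this commit:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (guard_sender, mut guard_receiver) = mpsc::channel::<()>(1);

    for i in 0..3 {
        let guard = guard_sender.clone();
        tokio::spawn(async move {
            // Hold the guard for the lifetime of the task; it is dropped on exit.
            let _guard = guard;
            println!("task {i} finished");
        });
    }

    // Drop the original sender so only the spawned tasks keep the channel open.
    drop(guard_sender);

    // recv() returns None once every guard has been dropped, i.e. all tasks exited.
    let _ = guard_receiver.recv().await;
    println!("all tasks have shut down");
}
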
@@ -1,5 +1,4 @@
 /// Inspired by https://github.com/hatoo/oha/blob/master/src/monitor.rs
-use crate::Message;
 use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers};
 use crossterm::{event, ExecutableCommand};
 use std::io;
@@ -15,6 +14,8 @@ use tui::widgets::{
     Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
 };
 use tui::{symbols, Terminal};
+use text_generation_client::ClientError;
+use crate::generation::Message;
 
 pub(crate) struct UI {
     pub(crate) tokenizer_name: String,
@@ -22,8 +23,9 @@ pub(crate) struct UI {
     pub(crate) decode_length: u32,
     pub(crate) n_run: usize,
     pub(crate) batch_size: Vec<u32>,
-    pub(crate) receiver: mpsc::Receiver<Message>,
+    pub(crate) receiver: mpsc::Receiver<Result<Message, ClientError>>,
     pub(crate) shutdown_sender: broadcast::Sender<()>,
     pub(crate) _shutdown_guard_sender: mpsc::Sender<()>,
 }
 
 impl UI {
@@ -57,6 +59,7 @@ impl UI {
         let mut completed_runs: Vec<usize> = (0..self.batch_size.len()).map(|_| 0).collect();
         let mut completed_batch = 0;
         let mut current_batch_idx = 0;
+        let mut is_error = false;
 
         let mut terminal = {
             let backend = CrosstermBackend::new(io::stdout());
@@ -68,23 +71,22 @@ impl UI {
         loop {
             match self.receiver.try_recv() {
-                Ok(message) => match message {
+                Ok(message) => {
+                    match message {
                     Message::Prefill(step) => {
                         let latency = step.latency.as_millis() as f64;
-                        let throughput = step.batch_size as f64 / step.latency.as_secs_f64();
                         prefill_latencies[current_batch_idx].push(latency);
-                        prefill_throughputs[current_batch_idx].push(throughput);
+                        prefill_throughputs[current_batch_idx].push(step.throughput);
                     }
                     Message::Decode(step) => {
                         let latency = step.latency.as_millis() as f64;
-                        let throughput = (step.batch_size * step.decode_length) as f64
-                            / step.latency.as_secs_f64();
                         decode_latencies[current_batch_idx].push(latency);
-                        decode_throughputs[current_batch_idx].push(throughput);
+                        decode_throughputs[current_batch_idx].push(step.throughput);
                     }
-                    Message::IncreaseRun => {
+                    Message::Run(_) => {
                         completed_runs[current_batch_idx] += 1;
                     }
-                    Message::IncreaseBatch => {
+                    Message::EndBatch => {
                         prefill_batch_latency_throughput.push((
                             prefill_latencies[current_batch_idx].iter().sum::<f64>()
                                 / completed_runs[current_batch_idx] as f64,
@@ -103,6 +105,10 @@ impl UI {
                         current_batch_idx += 1;
                     }
                     }
+                    Message::Warmup => {}
+                    }
+                }
+                Err(_) => is_error = true
             },
             Err(TryRecvError::Empty) => {
                 break;
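
The UI polls its channel with try_recv so that drawing never blocks on the benchmark task. A minimal sketch of that non-blocking drain pattern, illustrative only and independent of the types in this commit (the drain helper is hypothetical):

use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;

// Drain every message already queued, then return so the caller can redraw.
fn drain<T>(receiver: &mut mpsc::Receiver<T>) -> Vec<T> {
    let mut messages = Vec::new();
    loop {
        match receiver.try_recv() {
            Ok(message) => messages.push(message),
            // Nothing ready right now: stop and render the current state.
            Err(TryRecvError::Empty) => break,
            // All senders dropped: nothing more will ever arrive.
            Err(TryRecvError::Disconnected) => break,
        }
    }
    messages
}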