mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
wip
This commit is contained in:
parent
a28a8ebdb5
commit
c0d793d2ca
@ -1,21 +1,22 @@
|
|||||||
mod ui;
|
mod ui;
|
||||||
|
|
||||||
use std::time::Duration;
|
use std::time::{Duration, Instant};
|
||||||
use tokenizers::{Tokenizer, TruncationDirection};
|
use tokenizers::{Tokenizer, TruncationDirection};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters};
|
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
|
||||||
use time::Instant;
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use crate::ui::UI;
|
use crate::ui::UI;
|
||||||
|
|
||||||
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
|
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
|
||||||
|
|
||||||
enum Step {
|
#[derive(Debug)]
|
||||||
|
pub(crate) enum Step {
|
||||||
Prefill,
|
Prefill,
|
||||||
Decode,
|
Decode,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Run {
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct Run {
|
||||||
step: Step,
|
step: Step,
|
||||||
batch_size: u32,
|
batch_size: u32,
|
||||||
sequence_length: u32,
|
sequence_length: u32,
|
||||||
@ -29,10 +30,8 @@ pub async fn run(
|
|||||||
sequence_length: u32,
|
sequence_length: u32,
|
||||||
decode_length: u32,
|
decode_length: u32,
|
||||||
runs: usize,
|
runs: usize,
|
||||||
// mut client: ShardedClient,
|
mut client: ShardedClient,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// let prefill_runs = benchmark_prefill(&tokenizer, &batch_size, &sequence_length, &decode_length, runs, &mut client).await;
|
|
||||||
|
|
||||||
let (sender, receiver) = mpsc::channel(8);
|
let (sender, receiver) = mpsc::channel(8);
|
||||||
|
|
||||||
|
|
||||||
@ -45,88 +44,110 @@ pub async fn run(
|
|||||||
}.draw()
|
}.draw()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let sequence = create_sequence(sequence_length, tokenizer);
|
||||||
|
|
||||||
for n in 0..runs {
|
|
||||||
sender.send(()).await.unwrap();
|
for b in batch_size {
|
||||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
for _ in 0..runs {
|
||||||
|
let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
|
||||||
|
sender.send(run).await.unwrap();
|
||||||
|
|
||||||
|
let run = run_decode(decode_batch, sequence_length, &mut client).await?;
|
||||||
|
sender.send(run).await.unwrap();
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
drop(sender);
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// async fn benchmark_prefill(tokenizer: &Tokenizer,
|
async fn run_prefill(
|
||||||
// batch_size: &Vec<u32>,
|
sequence: String,
|
||||||
// sequence_length: u32,
|
sequence_length: u32,
|
||||||
// decode_length: u32,
|
batch_size: u32,
|
||||||
// runs: u32,
|
decode_length: u32,
|
||||||
// client: &mut ShardedClient) -> Vec<Run> {
|
client: &mut ShardedClient) -> Result<(Run, Batch), ClientError> {
|
||||||
// let mut results = Vec::new();
|
let requests = (0..batch_size).map(|id| {
|
||||||
//
|
Request {
|
||||||
// let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
|
id: id.into(),
|
||||||
//
|
inputs: sequence.clone(),
|
||||||
// for s in sequence_length {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
// let sequence = create_sequence(s, lorem_ipsum_length, tokenizer);
|
temperature: 1.0,
|
||||||
// for b in batch_size {
|
top_k: 0,
|
||||||
// for d in decode_length {
|
top_p: 1.0,
|
||||||
// let requests = (0..*b).map(|id| {
|
typical_p: 1.0,
|
||||||
// Request {
|
do_sample: false,
|
||||||
// id: id.into(),
|
seed: 0,
|
||||||
// inputs: sequence.clone(),
|
repetition_penalty: 1.0,
|
||||||
// input_length: *s,
|
watermark: false,
|
||||||
// parameters: Some(NextTokenChooserParameters {
|
}),
|
||||||
// temperature: 1.0,
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
// top_k: 0,
|
max_new_tokens: decode_length,
|
||||||
// top_p: 1.0,
|
stop_sequences: vec![],
|
||||||
// typical_p: 1.0,
|
ignore_eos_token: true,
|
||||||
// do_sample: false,
|
}),
|
||||||
// seed: 0,
|
}
|
||||||
// repetition_penalty: 1.0,
|
}).collect();
|
||||||
// watermark: false,
|
|
||||||
// }),
|
let batch = Batch {
|
||||||
// stopping_parameters: Some(StoppingCriteriaParameters {
|
id: 0,
|
||||||
// max_new_tokens: *d,
|
requests,
|
||||||
// stop_sequences: vec![],
|
size: batch_size,
|
||||||
// ignore_eos_token: true,
|
};
|
||||||
// }),
|
|
||||||
// }
|
let start_time = Instant::now();
|
||||||
// }).collect();
|
let (_, decode_batch) = client.prefill(batch.clone()).await?;
|
||||||
//
|
let elasped = start_time.elapsed();
|
||||||
// let batch = Batch {
|
|
||||||
// id: 0,
|
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
|
||||||
// requests,
|
|
||||||
// size: *b,
|
let run = Run {
|
||||||
// };
|
step: Step::Prefill,
|
||||||
//
|
batch_size,
|
||||||
// for _ in 0..runs {
|
sequence_length,
|
||||||
// let start_time = Instant::now();
|
decode_length: 1,
|
||||||
// client.prefill(batch.clone()).await.unwrap();
|
time: elasped,
|
||||||
// let elasped = start_time.elapsed();
|
};
|
||||||
//
|
|
||||||
// client.clear_cache().await.unwrap();
|
Ok((run, decode_batch))
|
||||||
//
|
}
|
||||||
// results.push(Run {
|
|
||||||
// step: Step::Prefill,
|
async fn run_decode(batch: Batch, sequence_length: u32, client: &mut ShardedClient) -> Result<Run, ClientError> {
|
||||||
// batch_size: *b,
|
let batch_size = batch.size;
|
||||||
// sequence_length: *s,
|
let mut decode_length = 0;
|
||||||
// decode_length: *d,
|
let start_time = Instant::now();
|
||||||
// time: elasped,
|
|
||||||
// });
|
let mut next_batch = Some(batch);
|
||||||
// }
|
while let Some(batch) = next_batch {
|
||||||
// }
|
let result = client.decode(vec![batch]).await?;
|
||||||
// }
|
next_batch = result.1;
|
||||||
// }
|
decode_length += 1;
|
||||||
// results
|
}
|
||||||
// }
|
let elapsed = start_time.elapsed();
|
||||||
//
|
let run = Run {
|
||||||
// fn create_sequence(sequence_length: &u32, lorem_ipsum_length: usize, tokenizer: &Tokenizer) -> String {
|
step: Step::Decode,
|
||||||
// // Repeat lorem ipsum to cover sequence length
|
batch_size,
|
||||||
// let string_sequence = LOREM_IPSUM.repeat((0..*sequence_length).step_by(lorem_ipsum_length).len());
|
sequence_length,
|
||||||
// // Encode sequence
|
decode_length,
|
||||||
// let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
|
time: elapsed,
|
||||||
// // Truncate to sequence_length
|
};
|
||||||
// encoding.truncate(*sequence_length as usize, 0, TruncationDirection::Left);
|
Ok(run)
|
||||||
// // Decode
|
}
|
||||||
// tokenizer.decode(Vec::from(encoding.get_ids()), false).unwrap()
|
|
||||||
// }
|
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
|
||||||
|
let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
|
||||||
|
// Repeat lorem ipsum to cover sequence length
|
||||||
|
let string_sequence = LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
|
||||||
|
// Encode sequence
|
||||||
|
let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
|
||||||
|
// Truncate to sequence_length
|
||||||
|
encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
|
||||||
|
// Decode
|
||||||
|
tokenizer.decode(Vec::from(encoding.get_ids()), false).unwrap()
|
||||||
|
}
|
@ -45,12 +45,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
||||||
{
|
{
|
||||||
// Load local tokenizer
|
// Load local tokenizer
|
||||||
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
Tokenizer::from_file(local_path.join("tokenizer.json")).expect("unable to load local tokenizer")
|
||||||
} else {
|
} else {
|
||||||
// Download and instantiate tokenizer
|
// Download and instantiate tokenizer
|
||||||
// We need to download it outside of the Tokio runtime
|
// We need to download it outside of the Tokio runtime
|
||||||
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
Tokenizer::from_pretrained(tokenizer_name.clone(), None).expect("unable to load hub tokenizer")
|
||||||
};
|
};
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
@ -60,14 +61,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
init_logging();
|
init_logging();
|
||||||
|
|
||||||
// Instantiate sharded client from the master unix socket
|
// Instantiate sharded client from the master unix socket
|
||||||
// let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
tracing::info!("Connect to model server");
|
||||||
// .await
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
// .expect("Could not connect to server");
|
.await
|
||||||
|
.expect("Could not connect to server");
|
||||||
// Clear the cache; useful if the webserver rebooted
|
// Clear the cache; useful if the webserver rebooted
|
||||||
// sharded_client
|
sharded_client
|
||||||
// .clear_cache()
|
.clear_cache()
|
||||||
// .await
|
.await
|
||||||
// .expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
text_generation_benchmark::run(
|
text_generation_benchmark::run(
|
||||||
@ -76,7 +78,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
sequence_length,
|
sequence_length,
|
||||||
decode_length,
|
decode_length,
|
||||||
runs,
|
runs,
|
||||||
// sharded_client,
|
sharded_client,
|
||||||
).await.unwrap();
|
).await.unwrap();
|
||||||
});
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -14,12 +14,13 @@ use tui::text::{Span, Spans};
|
|||||||
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
|
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
|
||||||
use tui::Terminal;
|
use tui::Terminal;
|
||||||
use tokio::sync::mpsc::Receiver;
|
use tokio::sync::mpsc::Receiver;
|
||||||
|
use crate::{Run, Step};
|
||||||
|
|
||||||
pub struct UI {
|
pub(crate) struct UI {
|
||||||
pub n_run: usize,
|
pub(crate) n_run: usize,
|
||||||
pub n_batch: usize,
|
pub(crate) n_batch: usize,
|
||||||
pub n_batch_done: usize,
|
pub(crate) n_batch_done: usize,
|
||||||
pub run_receiver: Receiver<()>,
|
pub(crate) run_receiver: Receiver<Run>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UI {
|
impl UI {
|
||||||
@ -30,6 +31,10 @@ impl UI {
|
|||||||
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
||||||
io::stdout().execute(crossterm::cursor::Hide)?;
|
io::stdout().execute(crossterm::cursor::Hide)?;
|
||||||
|
|
||||||
|
let mut prefill_latency = Vec::new();
|
||||||
|
let mut prefill_throughput = Vec::new();
|
||||||
|
let mut decode_latency = Vec::new();
|
||||||
|
let mut decode_throughput = Vec::new();
|
||||||
let mut runs = Vec::new();
|
let mut runs = Vec::new();
|
||||||
|
|
||||||
let mut terminal = {
|
let mut terminal = {
|
||||||
@ -42,11 +47,20 @@ impl UI {
|
|||||||
loop {
|
loop {
|
||||||
match self.run_receiver.try_recv() {
|
match self.run_receiver.try_recv() {
|
||||||
Ok(run) => {
|
Ok(run) => {
|
||||||
// match report.as_ref() {
|
match run.step {
|
||||||
// Ok(report) => *status_dist.entry(report.status).or_default() += 1,
|
Step::Prefill => {
|
||||||
// Err(e) => *error_dist.entry(e.to_string()).or_default() += 1,
|
let latency = run.time.as_millis() as f64;
|
||||||
// }
|
let throughput = run.batch_size as f64 / run.time.as_secs_f64();
|
||||||
// all.push(report);
|
prefill_latency.push(latency);
|
||||||
|
prefill_throughput.push(throughput);
|
||||||
|
}
|
||||||
|
Step::Decode => {
|
||||||
|
let latency = run.time.as_millis() as f64;
|
||||||
|
let throughput = (run.batch_size * run.decode_length) as f64 / run.time.as_secs_f64();
|
||||||
|
decode_latency.push(latency);
|
||||||
|
decode_throughput.push(throughput);
|
||||||
|
}
|
||||||
|
}
|
||||||
runs.push(run);
|
runs.push(run);
|
||||||
}
|
}
|
||||||
Err(TryRecvError::Empty) => {
|
Err(TryRecvError::Empty) => {
|
||||||
@ -59,8 +73,6 @@ impl UI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let draw_start = Instant::now();
|
|
||||||
|
|
||||||
let batch_progress = (self.n_batch_done as f64 / self.n_batch as f64).clamp(0.0, 1.0);
|
let batch_progress = (self.n_batch_done as f64 / self.n_batch as f64).clamp(0.0, 1.0);
|
||||||
let run_progress = (runs.len() as f64 / self.n_run as f64).clamp(0.0, 1.0);
|
let run_progress = (runs.len() as f64 / self.n_run as f64).clamp(0.0, 1.0);
|
||||||
|
|
||||||
@ -128,10 +140,8 @@ impl UI {
|
|||||||
.ratio(run_progress);
|
.ratio(run_progress);
|
||||||
f.render_widget(run_gauge, top[1]);
|
f.render_widget(run_gauge, top[1]);
|
||||||
|
|
||||||
let data = vec![0.0];
|
let prefill_latency_texts = statis_spans(&prefill_latency, "ms", false);
|
||||||
|
let prefill_throughput_texts = statis_spans(&prefill_throughput, "tokens/secs", false);
|
||||||
let prefill_latency_texts = statis_spans(&data, "ms", false);
|
|
||||||
let prefill_throughput_texts = statis_spans(&data, "tokens/secs", false);
|
|
||||||
|
|
||||||
let prefill_latency_statics = Paragraph::new(prefill_latency_texts).block(
|
let prefill_latency_statics = Paragraph::new(prefill_latency_texts).block(
|
||||||
Block::default()
|
Block::default()
|
||||||
@ -146,6 +156,23 @@ impl UI {
|
|||||||
.borders(Borders::ALL),
|
.borders(Borders::ALL),
|
||||||
);
|
);
|
||||||
f.render_widget(prefill_throughput_statics, prefill_text[1]);
|
f.render_widget(prefill_throughput_statics, prefill_text[1]);
|
||||||
|
|
||||||
|
let decode_latency_texts = statis_spans(&decode_latency, "ms", false);
|
||||||
|
let decode_throughput_texts = statis_spans(&decode_throughput, "tokens/secs", false);
|
||||||
|
|
||||||
|
let decode_latency_statics = Paragraph::new(decode_latency_texts).block(
|
||||||
|
Block::default()
|
||||||
|
.title(Span::raw("Decode Latency"))
|
||||||
|
.borders(Borders::ALL),
|
||||||
|
);
|
||||||
|
f.render_widget(decode_latency_statics, decode_text[0]);
|
||||||
|
|
||||||
|
let decode_throughput_statics = Paragraph::new(decode_throughput_texts).block(
|
||||||
|
Block::default()
|
||||||
|
.title(Span::raw("Decode Throughput"))
|
||||||
|
.borders(Borders::ALL),
|
||||||
|
);
|
||||||
|
f.render_widget(decode_throughput_statics, decode_text[1]);
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let per_frame = Duration::from_secs(1) / 30 as u32;
|
let per_frame = Duration::from_secs(1) / 30 as u32;
|
||||||
@ -170,7 +197,7 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str, color: bool) -> Vec<Spa
|
|||||||
"Lowest: {:.4} {unit}",
|
"Lowest: {:.4} {unit}",
|
||||||
data
|
data
|
||||||
.iter()
|
.iter()
|
||||||
.max_by(|a, b| a.total_cmp(b))
|
.min_by(|a, b| a.total_cmp(b))
|
||||||
.unwrap_or(&std::f64::NAN)
|
.unwrap_or(&std::f64::NAN)
|
||||||
),
|
),
|
||||||
Style::default().fg(Color::Reset),
|
Style::default().fg(Color::Reset),
|
||||||
@ -180,7 +207,7 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str, color: bool) -> Vec<Spa
|
|||||||
"Highest: {:.4} {unit}",
|
"Highest: {:.4} {unit}",
|
||||||
data
|
data
|
||||||
.iter()
|
.iter()
|
||||||
.min_by(|a, b| a.total_cmp(b))
|
.max_by(|a, b| a.total_cmp(b))
|
||||||
.unwrap_or(&std::f64::NAN)
|
.unwrap_or(&std::f64::NAN)
|
||||||
),
|
),
|
||||||
Style::default().fg(Color::Reset),
|
Style::default().fg(Color::Reset),
|
||||||
|
Loading…
Reference in New Issue
Block a user