Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
wip
parent a28a8ebdb5
commit c0d793d2ca
@@ -1,21 +1,22 @@
mod ui;

use std::time::Duration;
use std::time::{Duration, Instant};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::time;
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters};
use time::Instant;
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters, ClientError};
use tokio::sync::mpsc;
use crate::ui::UI;

const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

enum Step {
#[derive(Debug)]
pub(crate) enum Step {
    Prefill,
    Decode,
}

struct Run {
#[derive(Debug)]
pub(crate) struct Run {
    step: Step,
    batch_size: u32,
    sequence_length: u32,
@@ -29,10 +30,8 @@ pub async fn run(
    sequence_length: u32,
    decode_length: u32,
    runs: usize,
    // mut client: ShardedClient,
    mut client: ShardedClient,
) -> Result<(), Box<dyn std::error::Error>> {
    // let prefill_runs = benchmark_prefill(&tokenizer, &batch_size, &sequence_length, &decode_length, runs, &mut client).await;

    let (sender, receiver) = mpsc::channel(8);

@@ -45,88 +44,110 @@ pub async fn run(
        }.draw()
    );

    let sequence = create_sequence(sequence_length, tokenizer);

    for n in 0..runs {
        sender.send(()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(500)).await;

    for b in batch_size {
        for _ in 0..runs {
            let (run, decode_batch) = run_prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
            sender.send(run).await.unwrap();

            let run = run_decode(decode_batch, sequence_length, &mut client).await?;
            sender.send(run).await.unwrap();

            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }

    drop(sender);

    tokio::time::sleep(Duration::from_millis(100)).await;

    Ok(())
}
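
// The loop above streams every finished Run to the UI task over a bounded
// tokio mpsc channel, so drawing never blocks the benchmark. A minimal
// standalone sketch of that producer/consumer pattern (not part of this
// commit; this Run is a stand-in for the struct above):

use tokio::sync::mpsc;

#[derive(Debug)]
struct Run {
    batch_size: u32,
}

#[tokio::main]
async fn main() {
    // Bounded channel, as above: the producer awaits if the consumer lags.
    let (sender, mut receiver) = mpsc::channel::<Run>(8);

    tokio::spawn(async move {
        for batch_size in [1, 2, 4, 8] {
            sender.send(Run { batch_size }).await.unwrap();
        }
        // The sender drops here, closing the channel so the receiver loop
        // ends; that is the same role drop(sender) plays above.
    });

    while let Some(run) = receiver.recv().await {
        println!("received {run:?}");
    }
}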

//
// async fn benchmark_prefill(tokenizer: &Tokenizer,
//                            batch_size: &Vec<u32>,
//                            sequence_length: u32,
//                            decode_length: u32,
//                            runs: u32,
//                            client: &mut ShardedClient) -> Vec<Run> {
//     let mut results = Vec::new();
//
//     let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
//
//     for s in sequence_length {
//         let sequence = create_sequence(s, lorem_ipsum_length, tokenizer);
//         for b in batch_size {
//             for d in decode_length {
//                 let requests = (0..*b).map(|id| {
//                     Request {
//                         id: id.into(),
//                         inputs: sequence.clone(),
//                         input_length: *s,
//                         parameters: Some(NextTokenChooserParameters {
//                             temperature: 1.0,
//                             top_k: 0,
//                             top_p: 1.0,
//                             typical_p: 1.0,
//                             do_sample: false,
//                             seed: 0,
//                             repetition_penalty: 1.0,
//                             watermark: false,
//                         }),
//                         stopping_parameters: Some(StoppingCriteriaParameters {
//                             max_new_tokens: *d,
//                             stop_sequences: vec![],
//                             ignore_eos_token: true,
//                         }),
//                     }
//                 }).collect();
//
//                 let batch = Batch {
//                     id: 0,
//                     requests,
//                     size: *b,
//                 };
//
//                 for _ in 0..runs {
//                     let start_time = Instant::now();
//                     client.prefill(batch.clone()).await.unwrap();
//                     let elapsed = start_time.elapsed();
//
//                     client.clear_cache().await.unwrap();
//
//                     results.push(Run {
//                         step: Step::Prefill,
//                         batch_size: *b,
//                         sequence_length: *s,
//                         decode_length: *d,
//                         time: elapsed,
//                     });
//                 }
//             }
//         }
//     }
//     results
// }
//
// fn create_sequence(sequence_length: &u32, lorem_ipsum_length: usize, tokenizer: &Tokenizer) -> String {
//     // Repeat lorem ipsum to cover sequence length
//     let string_sequence = LOREM_IPSUM.repeat((0..*sequence_length).step_by(lorem_ipsum_length).len());
//     // Encode sequence
//     let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
//     // Truncate to sequence_length
//     encoding.truncate(*sequence_length as usize, 0, TruncationDirection::Left);
//     // Decode
//     tokenizer.decode(Vec::from(encoding.get_ids()), false).unwrap()
// }
async fn run_prefill(
    sequence: String,
    sequence_length: u32,
    batch_size: u32,
    decode_length: u32,
    client: &mut ShardedClient) -> Result<(Run, Batch), ClientError> {
    let requests = (0..batch_size).map(|id| {
        Request {
            id: id.into(),
            inputs: sequence.clone(),
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
                top_p: 1.0,
                typical_p: 1.0,
                do_sample: false,
                seed: 0,
                repetition_penalty: 1.0,
                watermark: false,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
                stop_sequences: vec![],
                ignore_eos_token: true,
            }),
        }
    }).collect();

    let batch = Batch {
        id: 0,
        requests,
        size: batch_size,
    };

    let start_time = Instant::now();
    let (_, decode_batch) = client.prefill(batch.clone()).await?;
    let elapsed = start_time.elapsed();

    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

    let run = Run {
        step: Step::Prefill,
        batch_size,
        sequence_length,
        decode_length: 1,
        time: elapsed,
    };

    Ok((run, decode_batch))
}

async fn run_decode(batch: Batch, sequence_length: u32, client: &mut ShardedClient) -> Result<Run, ClientError> {
    let batch_size = batch.size;
    let mut decode_length = 0;
    let start_time = Instant::now();

    let mut next_batch = Some(batch);
    while let Some(batch) = next_batch {
        let result = client.decode(vec![batch]).await?;
        next_batch = result.1;
        decode_length += 1;
    }
    let elapsed = start_time.elapsed();
    let run = Run {
        step: Step::Decode,
        batch_size,
        sequence_length,
        decode_length,
        time: elapsed,
    };
    Ok(run)
}
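
// For context on how these Runs are consumed: the UI below turns each one
// into a latency and a throughput sample. Prefill emits one token per
// sequence, decode emits decode_length tokens per sequence. A minimal
// sketch of that arithmetic (std only; the numbers are made up):

use std::time::Duration;

fn main() {
    let (batch_size, decode_length) = (4u32, 16u32);
    let prefill_time = Duration::from_millis(50);
    let decode_time = Duration::from_millis(800);

    // Prefill: one new token per sequence in the batch.
    let prefill_throughput = batch_size as f64 / prefill_time.as_secs_f64();
    // Decode: decode_length new tokens per sequence in the batch.
    let decode_throughput = (batch_size * decode_length) as f64 / decode_time.as_secs_f64();

    println!("prefill: {prefill_throughput:.1} tokens/s, decode: {decode_throughput:.1} tokens/s");
}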

fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
    // Repeat lorem ipsum to cover sequence length
    let string_sequence = LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
    // Encode sequence
    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
    // Truncate to sequence_length
    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
    // Decode
    tokenizer.decode(Vec::from(encoding.get_ids()), false).unwrap()
}
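
// create_sequence builds a prompt that encodes to roughly the requested
// number of tokens: repeat LOREM_IPSUM enough times, encode, truncate the
// token ids from the left, then decode back to text. A hedged usage sketch,
// assuming a local tokenizer.json and the imports at the top of this file;
// note that re-encoding a decoded string is not guaranteed to be exactly
// length-preserving:

fn main() {
    let tokenizer = Tokenizer::from_file("tokenizer.json").expect("load tokenizer");
    let prompt = create_sequence(512, tokenizer.clone());
    // Round-trip to check how close we landed to the requested length.
    let n_tokens = tokenizer.encode(prompt, true).unwrap().len();
    println!("requested 512 tokens, got {n_tokens}");
}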

@@ -45,12 +45,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
    {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
        Tokenizer::from_file(local_path.join("tokenizer.json")).expect("unable to load local tokenizer")
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
        Tokenizer::from_pretrained(tokenizer_name.clone(), None).expect("unable to load hub tokenizer")
    };

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
@@ -60,14 +61,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
        init_logging();

        // Instantiate sharded client from the master unix socket
        // let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
        //     .await
        //     .expect("Could not connect to server");
        tracing::info!("Connect to model server");
        let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
            .await
            .expect("Could not connect to server");
        // Clear the cache; useful if the webserver rebooted
        // sharded_client
        //     .clear_cache()
        //     .await
        //     .expect("Unable to clear cache");
        sharded_client
            .clear_cache()
            .await
            .expect("Unable to clear cache");
        tracing::info!("Connected");

        text_generation_benchmark::run(
@@ -76,7 +78,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
            sequence_length,
            decode_length,
            runs,
            // sharded_client,
            sharded_client,
        ).await.unwrap();
    });
    Ok(())

@@ -14,12 +14,13 @@ use tui::text::{Span, Spans};
use tui::widgets::{BarChart, Block, Borders, Gauge, Paragraph};
use tui::Terminal;
use tokio::sync::mpsc::Receiver;
use crate::{Run, Step};

pub struct UI {
    pub n_run: usize,
    pub n_batch: usize,
    pub n_batch_done: usize,
    pub run_receiver: Receiver<()>,
pub(crate) struct UI {
    pub(crate) n_run: usize,
    pub(crate) n_batch: usize,
    pub(crate) n_batch_done: usize,
    pub(crate) run_receiver: Receiver<Run>,
}

impl UI {
@@ -30,6 +31,10 @@ impl UI {
        io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
        io::stdout().execute(crossterm::cursor::Hide)?;

        let mut prefill_latency = Vec::new();
        let mut prefill_throughput = Vec::new();
        let mut decode_latency = Vec::new();
        let mut decode_throughput = Vec::new();
        let mut runs = Vec::new();

        let mut terminal = {
@@ -42,11 +47,20 @@ impl UI {
        loop {
            match self.run_receiver.try_recv() {
                Ok(run) => {
                    // match report.as_ref() {
                    //     Ok(report) => *status_dist.entry(report.status).or_default() += 1,
                    //     Err(e) => *error_dist.entry(e.to_string()).or_default() += 1,
                    // }
                    // all.push(report);
                    match run.step {
                        Step::Prefill => {
                            let latency = run.time.as_millis() as f64;
                            let throughput = run.batch_size as f64 / run.time.as_secs_f64();
                            prefill_latency.push(latency);
                            prefill_throughput.push(throughput);
                        }
                        Step::Decode => {
                            let latency = run.time.as_millis() as f64;
                            let throughput = (run.batch_size * run.decode_length) as f64 / run.time.as_secs_f64();
                            decode_latency.push(latency);
                            decode_throughput.push(throughput);
                        }
                    }
                    runs.push(run);
                }
                Err(TryRecvError::Empty) => {
@@ -59,8 +73,6 @@ impl UI {
                }
            }

            let draw_start = Instant::now();

            let batch_progress = (self.n_batch_done as f64 / self.n_batch as f64).clamp(0.0, 1.0);
            let run_progress = (runs.len() as f64 / self.n_run as f64).clamp(0.0, 1.0);

@@ -128,10 +140,8 @@ impl UI {
                    .ratio(run_progress);
                f.render_widget(run_gauge, top[1]);

                let data = vec![0.0];

                let prefill_latency_texts = statis_spans(&data, "ms", false);
                let prefill_throughput_texts = statis_spans(&data, "tokens/secs", false);
                let prefill_latency_texts = statis_spans(&prefill_latency, "ms", false);
                let prefill_throughput_texts = statis_spans(&prefill_throughput, "tokens/secs", false);

                let prefill_latency_statics = Paragraph::new(prefill_latency_texts).block(
                    Block::default()
@@ -146,6 +156,23 @@ impl UI {
                        .borders(Borders::ALL),
                );
                f.render_widget(prefill_throughput_statics, prefill_text[1]);

                let decode_latency_texts = statis_spans(&decode_latency, "ms", false);
                let decode_throughput_texts = statis_spans(&decode_throughput, "tokens/secs", false);

                let decode_latency_statics = Paragraph::new(decode_latency_texts).block(
                    Block::default()
                        .title(Span::raw("Decode Latency"))
                        .borders(Borders::ALL),
                );
                f.render_widget(decode_latency_statics, decode_text[0]);

                let decode_throughput_statics = Paragraph::new(decode_throughput_texts).block(
                    Block::default()
                        .title(Span::raw("Decode Throughput"))
                        .borders(Borders::ALL),
                );
                f.render_widget(decode_throughput_statics, decode_text[1]);
            })?;

            let per_frame = Duration::from_secs(1) / 30 as u32;
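
// per_frame caps the redraw loop at roughly 30 FPS: after each draw, the
// loop sleeps only whatever is left of the ~33 ms frame budget. A minimal
// synchronous sketch of the same pacing (the real loop is async; it records
// draw_start before rendering, as shown above):

use std::time::{Duration, Instant};

fn main() {
    let per_frame = Duration::from_secs(1) / 30; // ~33 ms per frame
    for _ in 0..3 {
        let draw_start = Instant::now();
        // ... render one frame here ...
        let elapsed = draw_start.elapsed();
        if elapsed < per_frame {
            // Sleep only the remainder so slow frames don't stack delays.
            std::thread::sleep(per_frame - elapsed);
        }
    }
}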

@@ -170,7 +197,7 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str, color: bool) -> Vec<Spans<'a>>
            "Lowest: {:.4} {unit}",
            data
                .iter()
                .max_by(|a, b| a.total_cmp(b))
                .min_by(|a, b| a.total_cmp(b))
                .unwrap_or(&std::f64::NAN)
        ),
        Style::default().fg(Color::Reset),
@@ -180,7 +207,7 @@ fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str, color: bool) -> Vec<Spans<'a>>
            "Highest: {:.4} {unit}",
            data
                .iter()
                .min_by(|a, b| a.total_cmp(b))
                .max_by(|a, b| a.total_cmp(b))
                .unwrap_or(&std::f64::NAN)
        ),
        Style::default().fg(Color::Reset),
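
// The two hunks above fix swapped extrema: "Lowest" was computed with
// max_by and "Highest" with min_by. f64 is not Ord (NaN admits no total
// order), which is why min_by/max_by are paired with f64::total_cmp here
// instead of plain Iterator::min/max. A standalone sketch of the corrected
// logic:

fn main() {
    let data = vec![3.0_f64, 1.5, 2.25];
    let lowest = data.iter().min_by(|a, b| a.total_cmp(b)).unwrap_or(&f64::NAN);
    let highest = data.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&f64::NAN);
    println!("Lowest: {lowest:.4}, Highest: {highest:.4}");
}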