text-generation-inference/benchmark/src/generation.rs
drbh 04e1af94d7
Enable multiple LoRa adapters (#2010)
* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: prefer loraxs custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support if vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------

Co-authored-by: Derek <datavistics@gmail.com>
2024-06-25 14:46:27 -04:00

236 lines
7.4 KiB
Rust

use std::time::{Duration, Instant};
use text_generation_client::v3::{
Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, ClientError, Input};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
/// Filler text that `create_sequence` tokenizes and repeats to build the dummy prompt.
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
/// Measurements collected for a single prefill step.
#[derive(Debug, Clone)]
pub(crate) struct Prefill {
/// Wall-clock time spent in the prefill call.
pub(crate) latency: Duration,
/// Batch size divided by the latency in seconds.
pub(crate) throughput: f64,
}
/// Measurements collected for a full decode (all decode steps of one run).
#[derive(Debug, Clone)]
pub(crate) struct Decode {
/// Wall-clock time spent over the whole decode loop.
pub(crate) latency: Duration,
/// Total latency divided by the number of decode calls issued.
pub(crate) token_latency: Duration,
/// Batch size times the number of decode calls, divided by the latency in seconds.
pub(crate) throughput: f64,
}
/// Progress messages emitted by the benchmarking task on its run channel.
#[derive(Debug)]
pub(crate) enum Message {
/// Sent after each warmup pass (prefill + decode) completes.
Warmup,
/// Result of a measured prefill step.
Prefill(Prefill),
/// Result of a measured full decode.
Decode(Decode),
/// Sent after each measured run (prefill + decode) completes.
EndRun,
/// Sent once all warmups and runs for one batch size are done.
EndBatch,
}
/// Benchmarking task
#[allow(clippy::too_many_arguments)]
pub(crate) async fn generation_task(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
},
_ = shutdown_receiver.recv() => {}
}
}
/// Runs the warmup and measured prefill/decode passes for every batch size.
///
/// For each entry of `batch_size`, `warmups` passes are executed whose results
/// are discarded (each followed by a `Message::Warmup`), then `n_runs` measured
/// passes whose `Prefill` and `Decode` results are sent on `run_sender`, each
/// pass being closed by a `Message::EndRun`. A `Message::EndBatch` is sent once
/// a batch size is complete.
///
/// Failures to send on `run_sender` are ignored. The first client error aborts
/// the benchmark and is returned to the caller.
#[allow(clippy::too_many_arguments)]
async fn generate_runs(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    parameters: NextTokenChooserParameters,
    mut client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
    // The same dummy prompt is reused by every request of every pass
    let prompt = create_sequence(sequence_length, tokenizer);

    for size in batch_size {
        // Warmup passes: run the full prefill + decode but drop the results
        for _ in 0..warmups {
            let (_, cached_batch) = prefill(
                prompt.clone(),
                sequence_length,
                size,
                decode_length,
                parameters.clone(),
                top_n_tokens,
                &mut client,
            )
            .await?;
            decode(cached_batch, &mut client).await?;
            let _ = run_sender.send(Ok(Message::Warmup)).await;
        }

        // Measured passes: report prefill, then decode, then the end of the run
        for _ in 0..n_runs {
            let (prefill_stats, cached_batch) = prefill(
                prompt.clone(),
                sequence_length,
                size,
                decode_length,
                parameters.clone(),
                top_n_tokens,
                &mut client,
            )
            .await?;
            let _ = run_sender.send(Ok(Message::Prefill(prefill_stats))).await;

            let decode_stats = decode(cached_batch, &mut client).await?;
            let _ = run_sender.send(Ok(Message::Decode(decode_stats))).await;

            let _ = run_sender.send(Ok(Message::EndRun)).await;
        }

        // This batch size is done
        let _ = run_sender.send(Ok(Message::EndBatch)).await;
    }

    Ok(())
}
/// Run a prefill step
///
/// Builds a batch of `batch_size` identical requests carrying `sequence` as
/// their input, sends it to `client.prefill` and times that call.
///
/// Returns the prefill measurements together with the cached batch to feed
/// into `decode`.
///
/// # Errors
///
/// Returns the `ClientError` produced by `client.prefill`.
///
/// # Panics
///
/// Panics if the client does not return a cached batch after the prefill.
async fn prefill(
    sequence: String,
    sequence_length: u32,
    batch_size: u32,
    decode_length: u32,
    parameters: NextTokenChooserParameters,
    top_n_tokens: Option<u32>,
    client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
    // Create requests
    let requests = (0..batch_size)
        .map(|id| Request {
            id: id.into(),
            prefill_logprobs: false,
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text(sequence.clone()).into()],
            }),
            inputs: sequence.clone(),
            truncate: sequence_length,
            parameters: Some(parameters.clone()),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
                stop_sequences: vec![],
                ignore_eos_token: true, // Will not stop even if an EOS token is generated
            }),
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
            adapter_id: None,
        })
        .collect();

    let batch = Batch {
        id: 0,
        requests,
        size: batch_size,
        max_tokens: batch_size * (sequence_length + decode_length),
        max_blocks: 0,
    };

    // Run prefill.
    // The batch is moved into the call rather than cloned: it is not needed
    // afterwards, and cloning here would count the copy of every request in
    // the measured latency.
    let start_time = Instant::now();
    let (_, decode_batch, _) = client.prefill(batch).await?;

    // Get latency
    let latency = start_time.elapsed();

    // Compute throughput from latency and batch size
    let throughput = batch_size as f64 / latency.as_secs_f64();

    // Decode batch cannot be empty
    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

    let step = Prefill {
        latency,
        throughput,
    };

    Ok((step, decode_batch))
}
/// Run a full decode
///
/// Feeds the cached batch back into `client.decode` again and again until the
/// client stops returning a follow-up batch, timing the whole loop.
///
/// Returns the total latency, the latency per decode call and the throughput
/// derived from the batch size and the number of decode calls.
async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
    let batch_size = batch.size;
    // Number of decode calls issued; at least one since we start with a batch
    let mut steps: u32 = 0;

    let started = Instant::now();
    let mut pending = Some(batch);
    while let Some(current) = pending.take() {
        let response = client.decode(vec![current]).await?;
        // Second element is the batch to continue with, if any
        pending = response.1;
        steps += 1;
    }
    let latency = started.elapsed();

    Ok(Decode {
        latency,
        token_latency: latency / steps,
        throughput: (batch_size * steps) as f64 / latency.as_secs_f64(),
    })
}
/// Build the dummy prompt used by every benchmark request.
///
/// `LOREM_IPSUM` is repeated enough times to cover `sequence_length` tokens,
/// the result is encoded, truncated to `sequence_length` tokens and decoded
/// back into a string.
///
/// # Panics
///
/// Panics if the tokenizer fails to encode or decode, or if `LOREM_IPSUM`
/// encodes to zero tokens (the step of the range would then be zero).
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
    // How many tokens one copy of the filler text is worth
    let tokens_per_copy = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();

    // Counting the steps of the range gives ceil(sequence_length / tokens_per_copy)
    let copies = (0..sequence_length).step_by(tokens_per_copy).len();
    let filler = LOREM_IPSUM.repeat(copies);

    // Encode, cut down to the requested number of tokens, then decode
    let mut tokens = tokenizer.encode(filler, true).unwrap();
    tokens.truncate(sequence_length as usize, 0, TruncationDirection::Left);
    tokenizer.decode(tokens.get_ids(), false).unwrap()
}