text-generation-inference/benchmark/src/generation.rs
Nicolas Patry e415b690a6
Lots of improvements (Still 2 allocators) (#2449)
* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* Override the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integration tests change (seems linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2024-08-29 16:29:01 +02:00

238 lines
7.4 KiB
Rust

use std::time::{Duration, Instant};
use text_generation_client::v3::{
Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, ClientError, Input};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
/// Filler text that `create_sequence` repeats to build the dummy benchmark prompt.
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
/// Measurements collected for a single prefill step.
#[derive(Debug, Clone)]
pub(crate) struct Prefill {
/// Wall-clock time spent in the prefill call.
pub(crate) latency: Duration,
/// Batch size divided by `latency` expressed in seconds.
pub(crate) throughput: f64,
}
/// Measurements collected for a full decode (all decode steps of one batch).
#[derive(Debug, Clone)]
pub(crate) struct Decode {
/// Wall-clock time spent over the whole decode loop.
pub(crate) latency: Duration,
/// `latency` divided by the number of decode steps that were executed.
pub(crate) token_latency: Duration,
/// `batch size * number of decode steps` divided by `latency` expressed in seconds.
pub(crate) throughput: f64,
}
/// Progress and result messages emitted by the benchmarking task.
#[derive(Debug)]
pub(crate) enum Message {
/// One warmup iteration (prefill + decode) has completed.
Warmup,
/// Measurements of one measured prefill step.
Prefill(Prefill),
/// Measurements of one measured full decode.
Decode(Decode),
/// One measured run (prefill + decode) has completed.
EndRun,
/// All warmups and measured runs for the current batch size have completed.
EndBatch,
}
/// Benchmarking task
#[allow(clippy::too_many_arguments)]
pub(crate) async fn generation_task(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
},
_ = shutdown_receiver.recv() => {}
}
}
/// Run the whole benchmark and report progress on `run_sender`.
///
/// For every entry of `batch_size`, in order:
/// 1. `warmups` unmeasured prefill + decode iterations, each followed by `Message::Warmup`;
/// 2. `n_runs` measured iterations, each emitting `Message::Prefill`, `Message::Decode`
///    and `Message::EndRun`;
/// 3. a final `Message::EndBatch`.
///
/// Delivery of progress messages is best effort: send failures are ignored.
///
/// # Errors
/// Returns the first `ClientError` raised by a prefill or decode call.
#[allow(clippy::too_many_arguments)]
async fn generate_runs(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    parameters: NextTokenChooserParameters,
    mut client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
    // The same dummy prompt is shared by every request of every batch
    let prompt = create_sequence(sequence_length, tokenizer);

    for current_batch_size in batch_size {
        // Warmup iterations: measurements are thrown away, only progress is reported
        for _ in 0..warmups {
            let (_, cached_batch) = prefill(
                prompt.clone(),
                sequence_length,
                current_batch_size,
                decode_length,
                parameters.clone(),
                top_n_tokens,
                &mut client,
            )
            .await?;
            decode(cached_batch, &mut client).await?;

            let _ = run_sender.send(Ok(Message::Warmup)).await;
        }

        // Measured iterations
        for _ in 0..n_runs {
            let (prefill_step, cached_batch) = prefill(
                prompt.clone(),
                sequence_length,
                current_batch_size,
                decode_length,
                parameters.clone(),
                top_n_tokens,
                &mut client,
            )
            .await?;
            let _ = run_sender.send(Ok(Message::Prefill(prefill_step))).await;

            let decode_step = decode(cached_batch, &mut client).await?;
            let _ = run_sender.send(Ok(Message::Decode(decode_step))).await;

            // This run is complete
            let _ = run_sender.send(Ok(Message::EndRun)).await;
        }

        // Every run for this batch size is complete
        let _ = run_sender.send(Ok(Message::EndBatch)).await;
    }

    Ok(())
}
/// Run a single prefill step and time it.
///
/// Builds `batch_size` identical requests from `sequence`, submits them as one batch
/// and measures the wall-clock time of the prefill call.
///
/// Returns the measurements together with the cached batch that must be passed to
/// `decode` to continue generation.
///
/// # Errors
/// Propagates the `ClientError` returned by the prefill call.
///
/// # Panics
/// Panics if the prefill call succeeds but returns no batch to decode.
async fn prefill(
    sequence: String,
    sequence_length: u32,
    batch_size: u32,
    decode_length: u32,
    parameters: NextTokenChooserParameters,
    top_n_tokens: Option<u32>,
    client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
    // Create requests
    let requests = (0..batch_size)
        .map(|id| Request {
            id: id.into(),
            prefill_logprobs: false,
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text(sequence.clone()).into()],
            }),
            inputs: sequence.clone(),
            truncate: sequence_length,
            add_special_tokens: true,
            parameters: Some(parameters.clone()),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
                stop_sequences: vec![],
                ignore_eos_token: true, // Will not stop even if a eos token is generated
            }),
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
            adapter_id: None,
        })
        .collect();

    let batch = Batch {
        id: 0,
        requests,
        size: batch_size,
        max_tokens: batch_size * (sequence_length + decode_length),
        max_blocks: 0,
    };

    // Run prefill.
    // The batch is moved into the call rather than cloned: it is not needed afterwards,
    // and a deep copy of every request made after the timer has started would be
    // counted as prefill latency.
    let start_time = Instant::now();
    let (_, decode_batch, _) = client.prefill(batch).await?;

    // Get latency
    let latency = start_time.elapsed();

    // Compute throughput from latency and batch size
    let throughput = batch_size as f64 / latency.as_secs_f64();

    // Decode batch cannot be empty
    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

    let step = Prefill {
        latency,
        throughput,
    };

    Ok((step, decode_batch))
}
/// Decode a cached batch to completion and time the whole loop.
///
/// The batch returned by each decode call is fed back into the next call until the
/// client returns no further batch.
///
/// Returns the total latency, the average latency per decode step and the throughput
/// (`batch size * number of steps` per second).
///
/// # Errors
/// Propagates the first `ClientError` returned by a decode call.
async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
    let batch_size = batch.size;
    // Number of decode calls performed; at least one, since we start with a batch
    let mut steps: u32 = 0;
    let started = Instant::now();

    let mut pending = Some(batch);
    while let Some(current) = pending.take() {
        // The second element of the response is the batch to continue with, if any
        pending = client.decode(vec![current]).await?.1;
        steps += 1;
    }

    let latency = started.elapsed();
    let throughput = (batch_size * steps) as f64 / latency.as_secs_f64();

    Ok(Decode {
        latency,
        token_latency: latency / steps,
        throughput,
    })
}
/// Create a dummy sequence of the correct length.
///
/// The filler text is repeated often enough to cover `sequence_length` tokens, encoded,
/// truncated to `sequence_length` tokens and decoded back to a string.
///
/// # Panics
/// Panics if the tokenizer fails to encode or decode the filler text.
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
    // Count content tokens only. Special tokens are added once per encoding, not once
    // per repetition, so counting them here would overestimate what each repetition
    // contributes and could leave the repeated text shorter than `sequence_length`.
    // `max(1)` keeps `step_by` below from panicking on a zero step.
    let lorem_ipsum_length = tokenizer
        .encode(LOREM_IPSUM, false)
        .expect("failed to tokenize the filler text")
        .len()
        .max(1);

    // Repeat lorem ipsum to cover sequence length:
    // the number of steps is ceil(sequence_length / lorem_ipsum_length)
    let repetitions = (0..sequence_length).step_by(lorem_ipsum_length).len();
    let string_sequence = LOREM_IPSUM.repeat(repetitions);

    // Encode sequence
    let mut encoding = tokenizer
        .encode(string_sequence, true)
        .expect("failed to tokenize the dummy sequence");

    // Truncate to sequence_length
    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);

    // Decode
    tokenizer
        .decode(encoding.get_ids(), false)
        .expect("failed to decode the dummy sequence")
}