use std::time::{Duration, Instant};
use text_generation_client::v3::{
    Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
    StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, ClientError, Input};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};

const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

#[derive(Debug, Clone)]
pub(crate) struct Prefill {
    pub(crate) latency: Duration,
    pub(crate) throughput: f64,
}

#[derive(Debug, Clone)]
pub(crate) struct Decode {
    pub(crate) latency: Duration,
    pub(crate) token_latency: Duration,
    pub(crate) throughput: f64,
}

#[derive(Debug)]
pub(crate) enum Message {
    Warmup,
    Prefill(Prefill),
    Decode(Decode),
    EndRun,
    EndBatch,
}
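
// Message flow per batch size: `warmups` × `Warmup`, then for each of the `n_runs`
// measured runs a `Prefill`, a `Decode` and an `EndRun`, and finally an `EndBatch`
// once all runs for that batch size have completed.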

/// Benchmarking task
#[allow(clippy::too_many_arguments)]
pub(crate) async fn generation_task(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    parameters: NextTokenChooserParameters,
    client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
    mut shutdown_receiver: broadcast::Receiver<()>,
    _shutdown_guard_sender: mpsc::Sender<()>,
) {
    // End task if a message is received on shutdown_receiver
    // _shutdown_guard_sender will be dropped once the task is finished
    tokio::select! {
        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
            if let Err(err) = res {
                run_sender.send(Err(err)).await.unwrap_or(());
            }
        },
        _ = shutdown_receiver.recv() => {}
    }
}
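
// A minimal sketch of how a caller might spawn `generation_task`; the channel
// capacities, variable names and runtime setup are illustrative, not taken from the
// crate's actual entry point:
//
//     let (run_sender, run_receiver) = mpsc::channel(8);
//     let (shutdown_sender, shutdown_receiver) = broadcast::channel(1);
//     let (shutdown_guard_sender, shutdown_guard_receiver) = mpsc::channel(1);
//     tokio::spawn(generation_task(
//         tokenizer, batch_size, sequence_length, decode_length, top_n_tokens,
//         n_runs, warmups, parameters, sharded_client, run_sender,
//         shutdown_receiver, shutdown_guard_sender,
//     ));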

/// Benchmark prefill/decode
#[allow(clippy::too_many_arguments)]
async fn generate_runs(
    tokenizer: Tokenizer,
    batch_size: Vec<u32>,
    sequence_length: u32,
    decode_length: u32,
    top_n_tokens: Option<u32>,
    n_runs: usize,
    warmups: usize,
    parameters: NextTokenChooserParameters,
    mut client: ShardedClient,
    run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
    // Create a dummy sequence
    let sequence = create_sequence(sequence_length, tokenizer);

    for b in batch_size {
        // Warmups on batch size
        for _ in 0..warmups {
            let (_, decode_batch) = prefill(
                sequence.clone(),
                sequence_length,
                b,
                decode_length,
                parameters.clone(),
                top_n_tokens,
                &mut client,
            )
            .await?;
            let _ = decode(decode_batch, &mut client).await?;
            // Send warmup message
            run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
        }

        for _ in 0..n_runs {
            let (prefill, decode_batch) = prefill(
                sequence.clone(),
                sequence_length,
                b,
                decode_length,
                parameters.clone(),
                top_n_tokens,
                &mut client,
            )
            .await?;
            // Send prefill message
            run_sender
                .send(Ok(Message::Prefill(prefill)))
                .await
                .unwrap_or(());

            let decode = decode(decode_batch, &mut client).await?;

            // Send decode message
            run_sender
                .send(Ok(Message::Decode(decode)))
                .await
                .unwrap_or(());

            // Send run ended message
            run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());
        }
        // Batch ended
        run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
    }
    Ok(())
}
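
// `prefill` below builds a batch of identical requests and times a single prefill pass.
// Because `ignore_eos_token` is set and `max_new_tokens` equals `decode_length`, every
// request should decode exactly `decode_length` tokens, keeping runs comparable.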

/// Run a prefill step
async fn prefill(
    sequence: String,
    sequence_length: u32,
    batch_size: u32,
    decode_length: u32,
    parameters: NextTokenChooserParameters,
    top_n_tokens: Option<u32>,
    client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
    // Create requests
    let requests = (0..batch_size)
        .map(|id| Request {
            id: id.into(),
            prefill_logprobs: false,
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text(sequence.clone()).into()],
            }),
            inputs: sequence.clone(),
            truncate: sequence_length,
            add_special_tokens: true,
            parameters: Some(parameters.clone()),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
                stop_sequences: vec![],
                ignore_eos_token: true, // Will not stop even if an EOS token is generated
            }),
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
            adapter_id: None,
        })
        .collect();

    let batch = Batch {
        id: 0,
        requests,
        size: batch_size,
        max_tokens: batch_size * (sequence_length + decode_length),
        max_blocks: 0,
    };

    // Run prefill
    let start_time = Instant::now();
    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;

    // Get latency
    let latency = start_time.elapsed();

    // Compute throughput from latency and batch size
    let throughput = batch_size as f64 / latency.as_secs_f64();

    // Decode batch cannot be empty
    let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");

    let step = Prefill {
        latency,
        throughput,
    };

    Ok((step, decode_batch))
}
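
// Each `ShardedClient::decode` call produces one token per request, so the loop below
// is expected to iterate `decode_length` times; `token_latency` is the mean latency of
// a single decode step.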

/// Run a full decode
async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
    let mut decode_length = 0;
    let batch_size = batch.size;

    let start_time = Instant::now();

    // Full decode over decode length
    let mut next_batch = Some(batch);
    while let Some(batch) = next_batch {
        let result = client.decode(vec![batch]).await?;
        next_batch = result.1;
        decode_length += 1;
    }

    // Get latency
    let latency = start_time.elapsed();
    let token_latency = latency / decode_length;

    // Compute throughput from latency, batch size and decode length
    let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();

    let step = Decode {
        latency,
        token_latency,
        throughput,
    };
    Ok(step)
}
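
// `(0..sequence_length).step_by(lorem_ipsum_length).len()` evaluates to
// ceil(sequence_length / lorem_ipsum_length), so the text is repeated just enough times
// to produce at least `sequence_length` tokens before the left truncation below.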

/// Create a dummy sequence of the correct length
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
    // Repeat lorem ipsum to cover sequence length
    let string_sequence =
        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
    // Encode sequence
    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
    // Truncate to sequence_length
    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
    // Decode
    tokenizer.decode(encoding.get_ids(), false).unwrap()
}