Adjust warmup to all possible bucket sizes and decode batch size = 1 (#113)

This commit is contained in:
jkaniecki 2024-03-27 11:59:51 +01:00 committed by GitHub
parent 9796b0e10d
commit 56f00a552b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 13 deletions

View File

@ -106,6 +106,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: Option<u32>,
) -> Result<Option<u32>> {
let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
if !warmup_enabled {
@ -123,7 +124,12 @@ impl Client {
// get all possible sequence lengths for prefill
let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
if let Some(&last) = seq_lengths.last() {
if last < max_input_length {
seq_lengths.push(max_input_length);
}
}
// execute batch for each combination of batch size and sequence length
let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
@ -134,11 +140,29 @@ impl Client {
}
let mut id_counter: u64 = 0;
let num_batches = match max_batch_total_tokens {
Some(val) => {
if val == max_total_tokens {
1
} else {
2
}
}
None => 2, // If max_batch_total_tokens is None, create two batches
};
for shape in shapes.iter() {
// create two batches in order to trigger concatenate operation
// in case decode bs=1 create one batch
let batches: Vec<Batch> = vec![
self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
self.create_warmup_batch(
*shape,
&mut id_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
);
num_batches
];
let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
@ -151,8 +175,15 @@ impl Client {
}
for greedy_shape in greedy_shapes.iter() {
let batches: Vec<Batch> = vec![
self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
self.create_warmup_batch(
*greedy_shape,
&mut id_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
true,
);
num_batches
];
let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
@ -233,14 +264,21 @@ impl Client {
// generate random tokens
let mut rng = rand::thread_rng();
let range = Uniform::new(2, 8192);
let tokens = input_length - seq_bucket_size / 2;
let tokens = if input_length % seq_bucket_size == 0 {
input_length - seq_bucket_size / 2
} else {
input_length - (input_length % seq_bucket_size) / 2
};
(0..tokens)
.map(|_| rng.sample(&range).to_string())
.collect::<Vec<String>>()
.join(", ")
} else {
// repeat test string to get expected input shape
let bucket_id = input_length / seq_bucket_size;
let mut bucket_id = input_length / seq_bucket_size;
if input_length % seq_bucket_size != 0 {
bucket_id += 1
}
let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
"_test ".to_string().repeat(repeats as usize)
}

View File

@ -96,12 +96,13 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: Option<u32>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
})
.collect();
// Take the minimum value

View File

@ -220,7 +220,7 @@ fn main() -> Result<(), RouterError> {
// Warmup model
tracing::info!("Warming up model");
let max_supported_batch_total_tokens = match sharded_client
.warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
.warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, max_batch_total_tokens)
.await
.map_err(RouterError::Warmup)?
{

View File

@ -461,7 +461,11 @@ class CausalLMBatch(Batch):
left_padding = max_input_length - input_len
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
else:
bucket_size = max_input_length - 1
left_padding = bucket_size - input_len
input_ids = tokenized_inputs["input_ids"]
@ -1049,15 +1053,17 @@ class CausalLM(Model):
return generations, batch if not stopped else None
def warmup(self, batches: List[CausalLMBatch]) -> None:
if len(batches) < 2:
return
# prefill
_, prefill_batch = self.generate_token([batches.pop(0)])
# decode
_, decode_batch = self.generate_token([prefill_batch])
# shifts
self.shifting_warmup(decode_batch)
# if decode bs is 1 warmup ends here
if len(batches) == 0:
return
# prefill
_, prefill_batch = self.generate_token([batches.pop(0)])
# concatenate and decode