Adjust warmup to all possible bucket sizes and decode batch size = 1 (#113)

jkaniecki 2024-03-27 11:59:51 +01:00 committed by GitHub
parent 9796b0e10d
commit 56f00a552b
4 changed files with 58 additions and 13 deletions

View File

@@ -106,6 +106,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let warmup_enabled: bool = env::var("WARMUP_ENABLED").ok().map_or(true, |value| value.to_lowercase() == "true");
         if !warmup_enabled {
@@ -123,7 +124,12 @@ impl Client {
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
-        let seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }

         // execute batch for each combination of batch size and sequence length
         let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
@@ -134,11 +140,29 @@ impl Client {
         }
         let mut id_counter: u64 = 0;
+        let num_batches = match max_batch_total_tokens {
+            Some(val) => {
+                if val == max_total_tokens {
+                    1
+                } else {
+                    2
+                }
+            }
+            None => 2, // If max_batch_total_tokens is None, create two batches
+        };
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
+            // in case decode bs=1 create one batch
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false),
-                self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false)
+                self.create_warmup_batch(
+                    *shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
@@ -151,8 +175,15 @@ impl Client {
         }
         for greedy_shape in greedy_shapes.iter() {
             let batches: Vec<Batch> = vec![
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
-                self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true),
+                self.create_warmup_batch(
+                    *greedy_shape,
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    true,
+                );
+                num_batches
             ];
             let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
@@ -233,14 +264,21 @@ impl Client {
             // generate random tokens
             let mut rng = rand::thread_rng();
             let range = Uniform::new(2, 8192);
-            let tokens = input_length - seq_bucket_size / 2;
+            let tokens = if input_length % seq_bucket_size == 0 {
+                input_length - seq_bucket_size / 2
+            } else {
+                input_length - (input_length % seq_bucket_size) / 2
+            };
             (0..tokens)
                 .map(|_| rng.sample(&range).to_string())
                 .collect::<Vec<String>>()
                 .join(", ")
         } else {
             // repeat test string to get expected input shape
-            let bucket_id = input_length / seq_bucket_size;
+            let mut bucket_id = input_length / seq_bucket_size;
+            if input_length % seq_bucket_size != 0 {
+                bucket_id += 1
+            }
             let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
             "_test ".to_string().repeat(repeats as usize)
         }
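Taken together, the warmup changes above enumerate a prefill bucket for every multiple of PAD_SEQUENCE_TO_MULTIPLE_OF up to max_input_length, append max_input_length itself when it is not already a bucket boundary, and pick a warmup token count that lands inside the intended bucket. A standalone sketch of that arithmetic, repackaged as free functions purely for illustration (the function names are not part of the patch):

// Enumerate the sequence-length buckets warmed up for prefill.
// `seq_bucket_size` plays the role of PAD_SEQUENCE_TO_MULTIPLE_OF.
fn warmup_seq_lengths(seq_bucket_size: u32, max_input_length: u32) -> Vec<u32> {
    let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length + 1)
        .step_by(seq_bucket_size as usize)
        .collect();
    // If max_input_length is not itself a multiple of the bucket size, warm it up too.
    if let Some(&last) = seq_lengths.last() {
        if last < max_input_length {
            seq_lengths.push(max_input_length);
        }
    }
    seq_lengths
}

// Pick a token count that falls inside the bucket for `input_length`,
// mirroring the create_warmup_batch change above.
fn warmup_token_count(input_length: u32, seq_bucket_size: u32) -> u32 {
    if input_length % seq_bucket_size == 0 {
        input_length - seq_bucket_size / 2
    } else {
        input_length - (input_length % seq_bucket_size) / 2
    }
}

For example, with PAD_SEQUENCE_TO_MULTIPLE_OF = 128 and max_input_length = 1000, warmup_seq_lengths returns 128, 256, ..., 896 and then appends 1000, and warmup_token_count(1000, 128) is 1000 - 104/2 = 948.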

View File

@@ -96,12 +96,13 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: Option<u32>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_total_tokens))
             })
             .collect();
         // Take the minimum value

View File

@@ -220,7 +220,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32, max_batch_total_tokens)
         .await
         .map_err(RouterError::Warmup)?
     {
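The router simply forwards max_batch_total_tokens into warmup; the client uses it to decide whether the decode batch can ever hold more than one request. The batch-count rule from the warmup change above, restated as a small free function for illustration (the function name is not part of the patch):

// One warmup batch suffices when the whole batch token budget equals a single
// request's max_total_tokens (decode batch size = 1, so concatenation never runs);
// otherwise build two batches so the concatenate path is warmed up as well.
fn warmup_batch_count(max_batch_total_tokens: Option<u32>, max_total_tokens: u32) -> usize {
    match max_batch_total_tokens {
        Some(val) if val == max_total_tokens => 1,
        // Unknown budget: keep the previous behaviour of two batches.
        _ => 2,
    }
}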

View File

@@ -461,7 +461,11 @@ class CausalLMBatch(Batch):
         left_padding = max_input_length - input_len
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            bucket_size = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF) - 1
+            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            if rounded_seq_len <= max_input_length:
+                bucket_size = rounded_seq_len - 1
+            else:
+                bucket_size = max_input_length - 1
             left_padding = bucket_size - input_len

         input_ids = tokenized_inputs["input_ids"]
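The padding-bucket rule above rounds a sequence length up to the next multiple of PAD_SEQUENCE_TO_MULTIPLE_OF but now caps the result at max_input_length, so padding no longer overshoots the input limit. A minimal sketch of the same computation, written in Rust to match the other sketches on this page; round_up here is a local stand-in for the helper the Python code uses:

// Round `value` up to the next multiple of `multiple`.
fn round_up(value: u32, multiple: u32) -> u32 {
    ((value + multiple - 1) / multiple) * multiple
}

// Padding bucket for a prompt of `input_len` tokens, capped at `max_input_length`.
fn padded_bucket_size(input_len: u32, pad_multiple: u32, max_input_length: u32) -> u32 {
    let rounded_seq_len = round_up(input_len + 1, pad_multiple);
    if rounded_seq_len <= max_input_length {
        rounded_seq_len - 1
    } else {
        max_input_length - 1
    }
}

For example, with PAD_SEQUENCE_TO_MULTIPLE_OF = 128 and max_input_length = 200: input_len = 100 rounds 101 up to 128, giving bucket_size 127 and left_padding 27, while input_len = 150 rounds 151 up to 256, which gets capped, giving bucket_size 199 and left_padding 49.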
@@ -1049,15 +1053,17 @@ class CausalLM(Model):
         return generations, batch if not stopped else None

     def warmup(self, batches: List[CausalLMBatch]) -> None:
-        if len(batches) < 2:
-            return
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # decode
         _, decode_batch = self.generate_token([prefill_batch])
         # shifts
         self.shifting_warmup(decode_batch)

+        # if decode bs is 1 warmup ends here
+        if len(batches) == 0:
+            return
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode