Warmup all decode buckets (#152)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
Karol Damaszke 2024-05-29 22:46:55 +02:00 committed by GitHub
parent 7b879fd1d8
commit 0e8f8726db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 19 deletions

View File

@ -120,15 +120,18 @@ impl Client {
env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
};
// get all possible prefill batch sizes
let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
// get all possible batch sizes
let decode_bucket_size: u32 = read_env_var("BATCH_BUCKET_SIZE", 8);
let max_decode_batch_size: u32 = match max_batch_size {
Some(max_batch_size) => max_batch_size as u32,
None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
None => decode_bucket_size
};
max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
let decode_batch_sizes: Vec<u32> = (decode_bucket_size..max_decode_batch_size+1).step_by(decode_bucket_size as usize).collect();
let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
let prefill_batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
// get all possible sequence lengths for prefill
let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
@ -140,8 +143,8 @@ impl Client {
}
// execute batch for each combination of batch size and sequence length
let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
for batch_size in &batch_sizes {
let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len() * seq_lengths.len());
for batch_size in &prefill_batch_sizes {
for seq_length in &seq_lengths {
shapes.push((*batch_size, *seq_length));
}
@ -183,9 +186,59 @@ impl Client {
let _response = self.stub.warmup(request).await?.into_inner();
}
// send batches to warmup all possible decode shapes
if decode_batch_sizes.len() > 1 {
let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
let mut batches: Vec<Batch> = vec![
self.create_warmup_batch(
(requests_send, seq_bucket_size),
&mut id_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
)
];
let get_current_decode_batch_size = |num: u32| -> u32 {
decode_batch_sizes.iter()
.filter(|&&x| x >= num)
.min()
.copied()
.unwrap()
};
let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
while current_decode_batch_size < max_decode_batch_size {
let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
batches.push(
self.create_warmup_batch(
(num_requests, seq_bucket_size),
&mut id_counter,
max_input_length,
max_total_tokens,
seq_bucket_size,
false,
)
);
requests_send += num_requests;
current_decode_batch_size = get_current_decode_batch_size(requests_send);
}
let request = tonic::Request::new(WarmupRequest {
batches,
max_input_length,
max_prefill_tokens,
max_total_tokens,
}).inject_context();
let _response = self.stub.warmup(request).await?.into_inner();
}
// send batches with default params to warm up Greedy search
let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
for batch_size in &batch_sizes {
let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
for batch_size in &prefill_batch_sizes {
greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
}
for greedy_shape in greedy_shapes.iter() {

View File

@ -1124,6 +1124,12 @@ class CausalLM(Model):
return generations, batch if not stopped else None, (forward_ns, decode_ns)
def warmup(self, batches: List[CausalLMBatch]) -> None:
def get_unfinished_requests(requests: List[CausalLMRequest]) -> List[int]:
return [
request.data.id for request in requests
if request.stopping_criteria.current_tokens < request.stopping_criteria.max_new_tokens
]
# prefill
_, prefill_batch, _ = self.generate_token([batches.pop(0)])
# decode
@ -1131,18 +1137,22 @@ class CausalLM(Model):
# shifts
self.shifting_warmup(decode_batch)
# if decode bs is 1 warmup ends here
if len(batches) == 0:
while decode_batch is not None:
_, decode_batch, _ = self.generate_token([decode_batch])
return
while len(batches) > 0:
# prefill
_, prefill_batch, _ = self.generate_token([batches.pop(0)])
# concatenate and decode
_, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
# decodes
# filter finished requests
request_ids = get_unfinished_requests(decode_batch.requests)
if len(request_ids) < len(decode_batch.requests):
decode_batch = decode_batch.filter(request_ids)
while decode_batch is not None:
# filter finished requests
request_ids = get_unfinished_requests(decode_batch.requests)
if len(request_ids) < len(decode_batch.requests):
decode_batch = decode_batch.filter(request_ids)
# decode
_, decode_batch, _ = self.generate_token([decode_batch])
def shifting_warmup(self, batch: CausalLMBatch) -> None: