Fix warmup shapes for corner cases (#136)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke 2024-05-06 11:35:27 +02:00 committed by GitHub
parent 4169ff8e6f
commit bad7fe720a
2 changed files with 27 additions and 11 deletions

@@ -121,7 +121,12 @@ impl Client {
         };
         // get all possible prefill batch sizes
-        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let max_decode_batch_size: u32 = match max_batch_size {
+            Some(max_batch_size) => max_batch_size as u32,
+            None => read_env_var("BATCH_BUCKET_SIZE", 8)
+        };
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
         let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
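The hunk above fixes the corner case where prefill warmup could enumerate batch sizes larger than decode will ever serve: the prefill maximum is now clamped to the decode maximum before bucketing. A minimal self-contained sketch of that arithmetic; the `read_env_var` stand-in and the example numbers are assumptions, not the client's actual code:

```rust
use std::cmp;

// Hypothetical stand-in for the client's read_env_var helper:
// parse an env var as u32, falling back to a default.
fn read_env_var(name: &str, default: u32) -> u32 {
    std::env::var(name)
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(default)
}

// Mirrors the shape arithmetic in the hunk above.
fn prefill_batch_sizes(
    max_prefill_tokens: u32,
    max_input_length: u32,
    max_batch_size: Option<usize>,
) -> Vec<u32> {
    let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
    let max_decode_batch_size: u32 = match max_batch_size {
        Some(bs) => bs as u32,
        None => read_env_var("BATCH_BUCKET_SIZE", 8),
    };
    // the fix: prefill warmup shapes may not exceed the decode batch limit
    max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
    let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
    (prefill_bucket_size..max_prefill_batch_size + 1)
        .step_by(prefill_bucket_size as usize)
        .collect()
}

fn main() {
    // 16384 prefill tokens / 1024 max input length = 16, clamped to decode bs 8
    assert_eq!(prefill_batch_sizes(16384, 1024, Some(8)), vec![4, 8]);
}
```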
@@ -142,13 +147,10 @@ impl Client {
             }
         }
-        // if max_batch_size is None, create two batches
-        let num_batches = max_batch_size.unwrap_or(2).min(2);
         let mut id_counter: u64 = 0;
         for shape in shapes.iter() {
-            // create two batches in order to trigger concatenate operation
-            // in case decode bs=1 create one batch
-            let batches: Vec<Batch> = vec![
+            let (batch_size, seq_length) = shape;
+            let mut batches: Vec<Batch> = vec![
                 self.create_warmup_batch(
                     *shape,
                     &mut id_counter,
@@ -156,9 +158,22 @@ impl Client {
                     max_total_tokens,
                     seq_bucket_size,
                     false,
-                );
-                num_batches
+                )
             ];
+            // if possible, create second batch in order to trigger concatenate operation
+            if *batch_size < max_decode_batch_size {
+                batches.push(
+                    self.create_warmup_batch(
+                        (1, *seq_length),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+            }
             let request = tonic::Request::new(WarmupRequest {
                 batches,
                 max_input_length,
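The pushed second batch replaces the old `vec![...; num_batches]` repetition: each warmup shape now gets one full-size batch, plus a size-1 batch that triggers the server's concatenate operation, and the size-1 batch is only added when the shape's batch size is still below `max_decode_batch_size`. A sketch of that rule, with hypothetical `Batch` and `create_warmup_batch` stand-ins (only the control flow mirrors the hunk):

```rust
// Illustrative types standing in for the client's Batch and
// create_warmup_batch; not the real gRPC types.
#[derive(Debug)]
struct Batch {
    id: u64,
    size: u32,
    seq_length: u32,
}

fn create_warmup_batch(shape: (u32, u32), id_counter: &mut u64) -> Batch {
    *id_counter += 1;
    Batch { id: *id_counter, size: shape.0, seq_length: shape.1 }
}

fn warmup_batches(
    shape: &(u32, u32),
    max_decode_batch_size: u32,
    id_counter: &mut u64,
) -> Vec<Batch> {
    let (batch_size, seq_length) = shape;
    let mut batches = vec![create_warmup_batch(*shape, id_counter)];
    // a second, size-1 batch makes the server exercise its concatenate path,
    // but only if the first batch leaves room under the decode limit
    if *batch_size < max_decode_batch_size {
        batches.push(create_warmup_batch((1, *seq_length), id_counter));
    }
    batches
}

fn main() {
    let mut id_counter: u64 = 0;
    // bs 4 < limit 8: two batches (concatenate path warmed up)
    println!("{:?}", warmup_batches(&(4, 128), 8, &mut id_counter));
    // bs 8 == limit 8: a single batch, no concatenate
    println!("{:?}", warmup_batches(&(8, 128), 8, &mut id_counter));
}
```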
@@ -168,7 +183,7 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }
-        //Send batches with deafult params to warm up Greedy search
+        // send batches with default params to warm up Greedy search
         let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
         for batch_size in &batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
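After the sampling warmup loop, each prefill batch size is replayed at the base sequence bucket with default parameters so greedy search gets warmed up as well (the `true` flag passed to `create_warmup_batch` in the next hunk presumably selects the default parameters). A sketch of the shape list being built here, with illustrative values:

```rust
// One greedy warmup shape per prefill batch size, all at the
// smallest sequence bucket.
fn greedy_shapes(batch_sizes: &[u32], seq_bucket_size: u32) -> Vec<(u32, u32)> {
    batch_sizes.iter().map(|&bs| (bs, seq_bucket_size)).collect()
}

fn main() {
    // e.g. batch sizes [4, 8] and a 128-token sequence bucket
    assert_eq!(greedy_shapes(&[4, 8], 128), vec![(4, 128), (8, 128)]);
}
```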
@ -182,8 +197,7 @@ impl Client {
max_total_tokens,
seq_bucket_size,
true,
);
num_batches
)
];
let request = tonic::Request::new(WarmupRequest {
batches,

@@ -1114,6 +1114,8 @@ class CausalLM(Model):
         # if decode bs is 1 warmup ends here
         if len(batches) == 0:
+            while decode_batch is not None:
+                _, decode_batch, _ = self.generate_token([decode_batch])
             return
         # prefill