From bad7fe720ac13db7766dbdae837a08eed3dc6bd7 Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Mon, 6 May 2024 11:35:27 +0200
Subject: [PATCH] Fix warmup shapes for corner cases (#136)

Co-authored-by: Karol Damaszke
---
 router/client/src/client.rs | 36 +++++++++++++------
 .../models/causal_lm.py     |  2 ++
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 51b75a49..67ab870d 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -121,7 +121,12 @@ impl Client {
         };
 
         // get all possible prefill batch sizes
-        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let max_decode_batch_size: u32 = match max_batch_size {
+            Some(max_batch_size) => max_batch_size as u32,
+            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
+        };
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
         let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
 
@@ -142,23 +147,33 @@ impl Client {
             }
         }
 
-        // if max_batch_size is None, create two batches
-        let num_batches = max_batch_size.unwrap_or(2).min(2);
         let mut id_counter: u64 = 0;
         for shape in shapes.iter() {
-            // create two batches in order to trigger concatenate operation
-            // in case decode bs=1 create one batch
-            let batches: Vec<Batch> = vec![
+            let (batch_size, seq_length) = shape;
+            let mut batches: Vec<Batch> = vec![
                 self.create_warmup_batch(
                     *shape,
                     &mut id_counter,
                     max_input_length,
                     max_total_tokens,
                     seq_bucket_size,
                     false,
-                );
-                num_batches
+                )
             ];
+            // if possible, create second batch in order to trigger concatenate operation
+            if *batch_size < max_decode_batch_size {
+                batches.push(
+                    self.create_warmup_batch(
+                        (1, *seq_length),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+            }
+
             let request = tonic::Request::new(WarmupRequest {
                 batches,
                 max_input_length,
@@ -168,7 +183,7 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }
-        //Send batches with deafult params to warm up Greedy search
+        // send batches with default params to warm up Greedy search
         let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
         for batch_size in &batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
         }
@@ -182,8 +197,7 @@ impl Client {
                     max_total_tokens,
                     seq_bucket_size,
                     true,
-                );
-                num_batches
+                )
             ];
             let request = tonic::Request::new(WarmupRequest {
                 batches,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 1b03eb3e..65ba35b9 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1114,6 +1114,8 @@ class CausalLM(Model):
 
         # if decode bs is 1 warmup ends here
         if len(batches) == 0:
+            while decode_batch is not None:
+                _, decode_batch, _ = self.generate_token([decode_batch])
             return
 
         # prefill
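
The first client.rs hunk is the heart of the fix: the maximum prefill batch size, previously derived only from max_prefill_tokens / max_input_length, is now clamped to the decode batch size the server will actually allow. Below is a minimal standalone sketch of that arithmetic for reading alongside the hunk; the function name prefill_batch_sizes and the read_env_var stub are hypothetical, the sample numbers assume both environment variables are unset, and only the clamp-then-bucket logic is taken from the diff:

    use std::cmp;

    // stand-in for the router's env helper (hypothetical implementation)
    fn read_env_var(name: &str, default: u32) -> u32 {
        std::env::var(name).ok().and_then(|v| v.parse().ok()).unwrap_or(default)
    }

    fn prefill_batch_sizes(
        max_prefill_tokens: u32,
        max_input_length: u32,
        max_batch_size: Option<usize>,
    ) -> Vec<u32> {
        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
        // the fix: prefill warmup shapes may not exceed the decode batch size
        let max_decode_batch_size: u32 = match max_batch_size {
            Some(max_batch_size) => max_batch_size as u32,
            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8),
        };
        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
        let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
        // one warmup shape per bucket, up to the clamped maximum
        (prefill_bucket_size..max_prefill_batch_size + 1)
            .step_by(prefill_bucket_size as usize)
            .collect()
    }

    fn main() {
        // 4096 / 512 = 8, clamped to max_batch_size 6, bucketed by 4 -> [4]
        assert_eq!(prefill_batch_sizes(4096, 512, Some(6)), vec![4]);
        // no max_batch_size: the decode limit falls back to 8 -> [4, 8]
        assert_eq!(prefill_batch_sizes(4096, 512, None), vec![4, 8]);
    }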
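
The second hunk replaces the old "always build num_batches copies" logic with a per-shape decision. Restated as a hypothetical helper (warmup_batch_sizes is not a name from the patch): the first warmup batch uses the bucketed prefill size, and a size-1 companion batch is added only when the pair still fits under the decode limit, so the concatenate operation gets warmed up without overshooting that limit.

    fn warmup_batch_sizes(batch_size: u32, max_decode_batch_size: u32) -> Vec<u32> {
        let mut sizes = vec![batch_size];
        // batch_size < max implies batch_size + 1 <= max, so the batch
        // produced by concatenation still fits within the decode limit
        if batch_size < max_decode_batch_size {
            sizes.push(1);
        }
        sizes
    }

This also removes the corner case in the old code, where num_batches = max_batch_size.unwrap_or(2).min(2) could send two full-size batches whose concatenation exceeded the decode limit.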
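
On the Python side, the causal_lm.py hunk handles the matching corner case flagged by the existing comment: when only one warmup shape exists, batches is empty after the first prefill and warmup previously returned before any decode step had run. The added loop drains decode_batch through generate_token until it is exhausted, so the decode path is also warmed up before the early return.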