From 0e8f8726db3f74bcb0093a5dd7cc1665067c16de Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Wed, 29 May 2024 22:46:55 +0200
Subject: [PATCH] Warmup all decode buckets (#152)

Co-authored-by: Karol Damaszke
---
 router/client/src/client.rs                       | 71 ++++++++++++++++---
 .../models/causal_lm.py                           | 30 +++++---
 2 files changed, 82 insertions(+), 19 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 67ab870d..142aae4e 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -120,15 +120,18 @@ impl Client {
             env::var(key).ok().map_or(default, |value| value.parse::<u32>().unwrap())
         };
 
-        // get all possible prefill batch sizes
-        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        // get all possible batch sizes
+        let decode_bucket_size: u32 = read_env_var("BATCH_BUCKET_SIZE", 8);
         let max_decode_batch_size: u32 = match max_batch_size {
             Some(max_batch_size) => max_batch_size as u32,
-            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
+            None => decode_bucket_size
         };
-        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
+        let decode_batch_sizes: Vec<u32> = (decode_bucket_size..max_decode_batch_size+1).step_by(decode_bucket_size as usize).collect();
+
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
-        let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
+        let prefill_batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
 
         // get all possible sequence lengths for prefill
         let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
@@ -140,8 +143,8 @@ impl Client {
         }
 
         // execute batch for each combination of batch size and sequence length
-        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
-        for batch_size in &batch_sizes {
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len() * seq_lengths.len());
+        for batch_size in &prefill_batch_sizes {
             for seq_length in &seq_lengths {
                 shapes.push((*batch_size, *seq_length));
             }
@@ -183,9 +186,59 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }
 
+        // send batches to warmup all possible decode shapes
+        if decode_batch_sizes.len() > 1 {
+            let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
+            let mut batches: Vec<Batch> = vec![
+                self.create_warmup_batch(
+                    (requests_send, seq_bucket_size),
+                    &mut id_counter,
+                    max_input_length,
+                    max_total_tokens,
+                    seq_bucket_size,
+                    false,
+                )
+            ];
+
+            let get_current_decode_batch_size = |num: u32| -> u32 {
+                decode_batch_sizes.iter()
+                    .filter(|&&x| x >= num)
+                    .min()
+                    .copied()
+                    .unwrap()
+            };
+
+            let mut current_decode_batch_size: u32 = get_current_decode_batch_size(requests_send);
+            while current_decode_batch_size < max_decode_batch_size {
+                let distance_to_next_bucket = current_decode_batch_size + decode_bucket_size - requests_send;
+                let num_requests: u32 = cmp::min(distance_to_next_bucket, max_prefill_batch_size);
+                batches.push(
+                    self.create_warmup_batch(
+                        (num_requests, seq_bucket_size),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+
+                requests_send += num_requests;
+                current_decode_batch_size = get_current_decode_batch_size(requests_send);
+            }
+
+            let request = tonic::Request::new(WarmupRequest {
+                batches,
+                max_input_length,
+                max_prefill_tokens,
+                max_total_tokens,
+            }).inject_context();
+            let _response = self.stub.warmup(request).await?.into_inner();
+        }
+
         // send batches with default params to warm up Greedy search
-        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
-        for batch_size in &batch_sizes {
+        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(prefill_batch_sizes.len());
+        for batch_size in &prefill_batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
         }
         for greedy_shape in greedy_shapes.iter() {
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cc8ddde8..bed8bcdd 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1124,6 +1124,12 @@ class CausalLM(Model):
         return generations, batch if not stopped else None, (forward_ns, decode_ns)
 
     def warmup(self, batches: List[CausalLMBatch]) -> None:
+        def get_unfinished_requests(requests: List[CausalLMRequest]) -> List[int]:
+            return [
+                request.data.id for request in requests
+                if request.stopping_criteria.current_tokens < request.stopping_criteria.max_new_tokens
+            ]
+
         # prefill
         _, prefill_batch, _ = self.generate_token([batches.pop(0)])
         # decode
@@ -1131,18 +1137,22 @@ class CausalLM(Model):
         # shifts
         self.shifting_warmup(decode_batch)
 
-        # if decode bs is 1 warmup ends here
-        if len(batches) == 0:
-            while decode_batch is not None:
-                _, decode_batch, _ = self.generate_token([decode_batch])
-            return
+        while len(batches) > 0:
+            # prefill
+            _, prefill_batch, _ = self.generate_token([batches.pop(0)])
+            # concatenate and decode
+            _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
+            # filter finished requests
+            request_ids = get_unfinished_requests(decode_batch.requests)
+            if len(request_ids) < len(decode_batch.requests):
+                decode_batch = decode_batch.filter(request_ids)
 
-        # prefill
-        _, prefill_batch, _ = self.generate_token([batches.pop(0)])
-        # concatenate and decode
-        _, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
-        # decodes
         while decode_batch is not None:
+            # filter finished requests
+            request_ids = get_unfinished_requests(decode_batch.requests)
+            if len(request_ids) < len(decode_batch.requests):
+                decode_batch = decode_batch.filter(request_ids)
+            # decode
             _, decode_batch, _ = self.generate_token([decode_batch])
 
     def shifting_warmup(self, batch: CausalLMBatch) -> None:
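The router-side change warms every decode bucket instead of only the shape the prefill batches happen to land on: starting from one prefill batch, each step sends just enough additional requests to spill the aggregated live batch into the next decode bucket, and all batches go out in a single WarmupRequest. A minimal Python sketch of that bucket-walking arithmetic (illustrative only, not part of the patch; plan_decode_warmup and requests_sent are made-up names mirroring the patch's requests_send logic):

from typing import List

def plan_decode_warmup(max_prefill_batch_size: int,
                       decode_bucket_size: int,
                       max_decode_batch_size: int) -> List[int]:
    """Return the size of each warmup prefill batch, in send order."""
    decode_batch_sizes = list(range(decode_bucket_size,
                                    max_decode_batch_size + 1,
                                    decode_bucket_size))

    def current_bucket(num_requests: int) -> int:
        # Smallest decode bucket that fits the live requests
        # (mirrors get_current_decode_batch_size in the patch).
        return min(b for b in decode_batch_sizes if b >= num_requests)

    requests_sent = min(max_prefill_batch_size, decode_bucket_size)
    sizes = [requests_sent]
    bucket = current_bucket(requests_sent)
    while bucket < max_decode_batch_size:
        # Send just enough new requests to spill into the next bucket,
        # capped by what a single prefill batch can carry.
        distance_to_next_bucket = bucket + decode_bucket_size - requests_sent
        num_requests = min(distance_to_next_bucket, max_prefill_batch_size)
        sizes.append(num_requests)
        requests_sent += num_requests
        bucket = current_bucket(requests_sent)
    return sizes

# With BATCH_BUCKET_SIZE=8, max decode batch 32 and max prefill batch 4,
# eight prefill batches of 4 walk the live batch through buckets
# 8, 16, 24 and 32:
print(plan_decode_warmup(4, 8, 32))  # [4, 4, 4, 4, 4, 4, 4, 4]

Each decode bucket corresponds to a distinct padded batch shape on Gaudi, so touching all of them during warmup keeps shape compilation out of the serving path.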
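On the server side, CausalLM.warmup now keeps folding prefill batches into the running decode batch and, after each step, drops requests whose stopping criteria are exhausted, so the live batch tracks the bucket being warmed instead of growing past it. The filter reduces to the predicate below (a self-contained sketch; FakeRequest is a hypothetical stand-in that exposes CausalLMRequest's request.data.id directly as id):

from dataclasses import dataclass
from typing import List

@dataclass
class StoppingCriteria:
    current_tokens: int
    max_new_tokens: int

@dataclass
class FakeRequest:
    id: int
    stopping_criteria: StoppingCriteria

def get_unfinished_requests(requests: List[FakeRequest]) -> List[int]:
    # Same predicate as the helper added to warmup: keep a request
    # while it has not yet produced max_new_tokens tokens.
    return [
        r.id for r in requests
        if r.stopping_criteria.current_tokens < r.stopping_criteria.max_new_tokens
    ]

requests = [
    FakeRequest(0, StoppingCriteria(current_tokens=10, max_new_tokens=10)),  # finished
    FakeRequest(1, StoppingCriteria(current_tokens=3, max_new_tokens=10)),   # still running
]
assert get_unfinished_requests(requests) == [1]

The restructured loop also absorbs the old early-return special case for a decode batch size of 1.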