From 2b1581edac3ce6a774e9a7f051539345014974cb Mon Sep 17 00:00:00 2001 From: jkaniecki <153085639+jkaniecki@users.noreply.github.com> Date: Fri, 22 Mar 2024 23:43:20 +0100 Subject: [PATCH] Warmup greedy search in next token chooser (#109) --- router/client/src/client.rs | 44 ++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 9aefaf55..1823fca3 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -137,13 +137,26 @@ impl Client { for shape in shapes.iter() { // create two batches in order to trigger concatenate operation let batches: Vec = vec![ - self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size), - self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size) + self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false), + self.create_warmup_batch(*shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, false) ]; let request = tonic::Request::new(WarmupRequest { batches }).inject_context(); let _response = self.stub.warmup(request).await?.into_inner(); } + // Send batches with default params to warm up Greedy search + let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len()); + for batch_size in &batch_sizes { + greedy_shapes.push((*batch_size, seq_bucket_size.clone())); + } + for greedy_shape in greedy_shapes.iter() { + let batches: Vec = vec![ + self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true), + self.create_warmup_batch(*greedy_shape, &mut id_counter, max_input_length, max_total_tokens, seq_bucket_size, true), + ]; + let request = tonic::Request::new(WarmupRequest { batches }).inject_context(); + let _response = self.stub.warmup(request).await?.into_inner(); + } Ok(None) // No 
support for maximum total tokens } @@ -155,16 +168,25 @@ impl Client { max_input_length: u32, max_total_tokens: u32, seq_bucket_size: u32, + default_params: bool, ) -> Batch { *id_counter += 1; let (batch_size, input_length) = shape; let mut requests = Vec::new(); for request_id in 0..batch_size { - requests.push(Request { - id: *id_counter + request_id as u64, - inputs: self.get_random_input(input_length, seq_bucket_size), - truncate: max_input_length, - parameters: Some(NextTokenChooserParameters { + let req_params = if default_params { + Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + watermark: false, + }) + } else { + Some(NextTokenChooserParameters { temperature: 0.9, top_k: 10, top_p: 0.9, @@ -173,7 +195,13 @@ impl Client { seed: 0, repetition_penalty: 1.2, watermark: true, - }), + }) + }; + requests.push(Request { + id: *id_counter + request_id as u64, + inputs: self.get_random_input(input_length, seq_bucket_size), + truncate: max_input_length, + parameters: req_params, stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: cmp::min(10, max_total_tokens - max_input_length), stop_sequences: vec![],