diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 142aae4e..c72926d4 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -161,6 +161,7 @@ impl Client { max_total_tokens, seq_bucket_size, false, + None, ) ]; // if possible, create second batch in order to trigger concatenate operation @@ -173,6 +174,7 @@ impl Client { max_total_tokens, seq_bucket_size, false, + None, ) ); } @@ -188,6 +190,13 @@ impl Client { // send batches to warmup all possible decode shapes if decode_batch_sizes.len() > 1 { + let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size { + decode_bucket_size + } else { + decode_bucket_size.div_ceil(max_prefill_batch_size) + }; + let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket; + let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size); let mut batches: Vec = vec![ self.create_warmup_batch( @@ -197,6 +206,7 @@ impl Client { max_total_tokens, seq_bucket_size, false, + Some(max_new_tokens), ) ]; @@ -220,6 +230,7 @@ impl Client { max_total_tokens, seq_bucket_size, false, + Some(max_new_tokens), ) ); @@ -250,6 +261,7 @@ impl Client { max_total_tokens, seq_bucket_size, true, + None, ) ]; let request = tonic::Request::new(WarmupRequest { @@ -272,6 +284,7 @@ impl Client { max_total_tokens: u32, seq_bucket_size: u32, default_params: bool, + max_new_tokens: Option, ) -> Batch { *id_counter += 1; let (batch_size, input_length) = shape; @@ -312,7 +325,7 @@ impl Client { truncate: max_input_length, parameters: req_params, stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: cmp::min(10, max_total_tokens - max_input_length), + max_new_tokens: cmp::min(max_new_tokens.unwrap_or(10), max_total_tokens - max_input_length), stop_sequences: vec![], ignore_eos_token: true, }),