diff --git a/router/client/src/client.rs b/router/client/src/client.rs index c72926d4..fc49e2a8 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -150,13 +150,15 @@ impl Client { } } - let mut id_counter: u64 = 0; + let mut batch_counter: u64 = 0; + let mut request_counter: u64 = 0; for shape in shapes.iter() { let (batch_size, seq_length) = shape; let mut batches: Vec = vec![ self.create_warmup_batch( *shape, - &mut id_counter, + &mut batch_counter, + &mut request_counter, max_input_length, max_total_tokens, seq_bucket_size, @@ -169,7 +171,8 @@ impl Client { batches.push( self.create_warmup_batch( (1, *seq_length), - &mut id_counter, + &mut batch_counter, + &mut request_counter, max_input_length, max_total_tokens, seq_bucket_size, @@ -201,7 +204,8 @@ impl Client { let mut batches: Vec = vec![ self.create_warmup_batch( (requests_send, seq_bucket_size), - &mut id_counter, + &mut batch_counter, + &mut request_counter, max_input_length, max_total_tokens, seq_bucket_size, @@ -225,7 +229,8 @@ impl Client { batches.push( self.create_warmup_batch( (num_requests, seq_bucket_size), - &mut id_counter, + &mut batch_counter, + &mut request_counter, max_input_length, max_total_tokens, seq_bucket_size, @@ -256,7 +261,8 @@ impl Client { let batches: Vec = vec![ self.create_warmup_batch( *greedy_shape, - &mut id_counter, + &mut batch_counter, + &mut request_counter, max_input_length, max_total_tokens, seq_bucket_size, @@ -279,17 +285,19 @@ impl Client { fn create_warmup_batch( &mut self, shape: (u32, u32), - id_counter: &mut u64, + batch_counter: &mut u64, + request_counter: &mut u64, max_input_length: u32, max_total_tokens: u32, seq_bucket_size: u32, default_params: bool, max_new_tokens: Option, ) -> Batch { - *id_counter += 1; + *batch_counter += 1; let (batch_size, input_length) = shape; let mut requests = Vec::new(); - for request_id in 0..batch_size { + for _ in 0..batch_size { + *request_counter += 1; let req_params = if default_params { Some(NextTokenChooserParameters { temperature: 1.0, @@ -320,7 +328,7 @@ impl Client { }) }; requests.push(Request { - id: *id_counter + request_id as u64, + id: *request_counter, inputs: self.get_random_input(input_length, seq_bucket_size), truncate: max_input_length, parameters: req_params, @@ -335,7 +343,7 @@ impl Client { } Batch { - id: *id_counter, + id: *batch_counter, size: requests.len() as u32, requests, max_tokens: max_total_tokens,