Set unique request id during warmup (#170)

This commit is contained in:
Karol Damaszke 2024-07-03 10:58:20 +02:00 committed by GitHub
parent 4b4382c6f8
commit 535a35db17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -150,13 +150,15 @@ impl Client {
} }
} }
let mut id_counter: u64 = 0; let mut batch_counter: u64 = 0;
let mut request_counter: u64 = 0;
for shape in shapes.iter() { for shape in shapes.iter() {
let (batch_size, seq_length) = shape; let (batch_size, seq_length) = shape;
let mut batches: Vec<Batch> = vec![ let mut batches: Vec<Batch> = vec![
self.create_warmup_batch( self.create_warmup_batch(
*shape, *shape,
&mut id_counter, &mut batch_counter,
&mut request_counter,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
seq_bucket_size, seq_bucket_size,
@ -169,7 +171,8 @@ impl Client {
batches.push( batches.push(
self.create_warmup_batch( self.create_warmup_batch(
(1, *seq_length), (1, *seq_length),
&mut id_counter, &mut batch_counter,
&mut request_counter,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
seq_bucket_size, seq_bucket_size,
@ -201,7 +204,8 @@ impl Client {
let mut batches: Vec<Batch> = vec![ let mut batches: Vec<Batch> = vec![
self.create_warmup_batch( self.create_warmup_batch(
(requests_send, seq_bucket_size), (requests_send, seq_bucket_size),
&mut id_counter, &mut batch_counter,
&mut request_counter,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
seq_bucket_size, seq_bucket_size,
@ -225,7 +229,8 @@ impl Client {
batches.push( batches.push(
self.create_warmup_batch( self.create_warmup_batch(
(num_requests, seq_bucket_size), (num_requests, seq_bucket_size),
&mut id_counter, &mut batch_counter,
&mut request_counter,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
seq_bucket_size, seq_bucket_size,
@ -256,7 +261,8 @@ impl Client {
let batches: Vec<Batch> = vec![ let batches: Vec<Batch> = vec![
self.create_warmup_batch( self.create_warmup_batch(
*greedy_shape, *greedy_shape,
&mut id_counter, &mut batch_counter,
&mut request_counter,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
seq_bucket_size, seq_bucket_size,
@ -279,17 +285,19 @@ impl Client {
fn create_warmup_batch( fn create_warmup_batch(
&mut self, &mut self,
shape: (u32, u32), shape: (u32, u32),
id_counter: &mut u64, batch_counter: &mut u64,
request_counter: &mut u64,
max_input_length: u32, max_input_length: u32,
max_total_tokens: u32, max_total_tokens: u32,
seq_bucket_size: u32, seq_bucket_size: u32,
default_params: bool, default_params: bool,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Batch { ) -> Batch {
*id_counter += 1; *batch_counter += 1;
let (batch_size, input_length) = shape; let (batch_size, input_length) = shape;
let mut requests = Vec::new(); let mut requests = Vec::new();
for request_id in 0..batch_size { for _ in 0..batch_size {
*request_counter += 1;
let req_params = if default_params { let req_params = if default_params {
Some(NextTokenChooserParameters { Some(NextTokenChooserParameters {
temperature: 1.0, temperature: 1.0,
@ -320,7 +328,7 @@ impl Client {
}) })
}; };
requests.push(Request { requests.push(Request {
id: *id_counter + request_id as u64, id: *request_counter,
inputs: self.get_random_input(input_length, seq_bucket_size), inputs: self.get_random_input(input_length, seq_bucket_size),
truncate: max_input_length, truncate: max_input_length,
parameters: req_params, parameters: req_params,
@ -335,7 +343,7 @@ impl Client {
} }
Batch { Batch {
id: *id_counter, id: *batch_counter,
size: requests.len() as u32, size: requests.len() as u32,
requests, requests,
max_tokens: max_total_tokens, max_tokens: max_total_tokens,