mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
Set unique request id during warmup (#170)
This commit is contained in:
parent
4b4382c6f8
commit
535a35db17
@ -150,13 +150,15 @@ impl Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut id_counter: u64 = 0;
|
let mut batch_counter: u64 = 0;
|
||||||
|
let mut request_counter: u64 = 0;
|
||||||
for shape in shapes.iter() {
|
for shape in shapes.iter() {
|
||||||
let (batch_size, seq_length) = shape;
|
let (batch_size, seq_length) = shape;
|
||||||
let mut batches: Vec<Batch> = vec![
|
let mut batches: Vec<Batch> = vec![
|
||||||
self.create_warmup_batch(
|
self.create_warmup_batch(
|
||||||
*shape,
|
*shape,
|
||||||
&mut id_counter,
|
&mut batch_counter,
|
||||||
|
&mut request_counter,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
seq_bucket_size,
|
seq_bucket_size,
|
||||||
@ -169,7 +171,8 @@ impl Client {
|
|||||||
batches.push(
|
batches.push(
|
||||||
self.create_warmup_batch(
|
self.create_warmup_batch(
|
||||||
(1, *seq_length),
|
(1, *seq_length),
|
||||||
&mut id_counter,
|
&mut batch_counter,
|
||||||
|
&mut request_counter,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
seq_bucket_size,
|
seq_bucket_size,
|
||||||
@ -201,7 +204,8 @@ impl Client {
|
|||||||
let mut batches: Vec<Batch> = vec![
|
let mut batches: Vec<Batch> = vec![
|
||||||
self.create_warmup_batch(
|
self.create_warmup_batch(
|
||||||
(requests_send, seq_bucket_size),
|
(requests_send, seq_bucket_size),
|
||||||
&mut id_counter,
|
&mut batch_counter,
|
||||||
|
&mut request_counter,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
seq_bucket_size,
|
seq_bucket_size,
|
||||||
@ -225,7 +229,8 @@ impl Client {
|
|||||||
batches.push(
|
batches.push(
|
||||||
self.create_warmup_batch(
|
self.create_warmup_batch(
|
||||||
(num_requests, seq_bucket_size),
|
(num_requests, seq_bucket_size),
|
||||||
&mut id_counter,
|
&mut batch_counter,
|
||||||
|
&mut request_counter,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
seq_bucket_size,
|
seq_bucket_size,
|
||||||
@ -256,7 +261,8 @@ impl Client {
|
|||||||
let batches: Vec<Batch> = vec![
|
let batches: Vec<Batch> = vec![
|
||||||
self.create_warmup_batch(
|
self.create_warmup_batch(
|
||||||
*greedy_shape,
|
*greedy_shape,
|
||||||
&mut id_counter,
|
&mut batch_counter,
|
||||||
|
&mut request_counter,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
seq_bucket_size,
|
seq_bucket_size,
|
||||||
@ -279,17 +285,19 @@ impl Client {
|
|||||||
fn create_warmup_batch(
|
fn create_warmup_batch(
|
||||||
&mut self,
|
&mut self,
|
||||||
shape: (u32, u32),
|
shape: (u32, u32),
|
||||||
id_counter: &mut u64,
|
batch_counter: &mut u64,
|
||||||
|
request_counter: &mut u64,
|
||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: u32,
|
||||||
seq_bucket_size: u32,
|
seq_bucket_size: u32,
|
||||||
default_params: bool,
|
default_params: bool,
|
||||||
max_new_tokens: Option<u32>,
|
max_new_tokens: Option<u32>,
|
||||||
) -> Batch {
|
) -> Batch {
|
||||||
*id_counter += 1;
|
*batch_counter += 1;
|
||||||
let (batch_size, input_length) = shape;
|
let (batch_size, input_length) = shape;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
for request_id in 0..batch_size {
|
for _ in 0..batch_size {
|
||||||
|
*request_counter += 1;
|
||||||
let req_params = if default_params {
|
let req_params = if default_params {
|
||||||
Some(NextTokenChooserParameters {
|
Some(NextTokenChooserParameters {
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
@ -320,7 +328,7 @@ impl Client {
|
|||||||
})
|
})
|
||||||
};
|
};
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: *id_counter + request_id as u64,
|
id: *request_counter,
|
||||||
inputs: self.get_random_input(input_length, seq_bucket_size),
|
inputs: self.get_random_input(input_length, seq_bucket_size),
|
||||||
truncate: max_input_length,
|
truncate: max_input_length,
|
||||||
parameters: req_params,
|
parameters: req_params,
|
||||||
@ -335,7 +343,7 @@ impl Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Batch {
|
Batch {
|
||||||
id: *id_counter,
|
id: *batch_counter,
|
||||||
size: requests.len() as u32,
|
size: requests.len() as u32,
|
||||||
requests,
|
requests,
|
||||||
max_tokens: max_total_tokens,
|
max_tokens: max_total_tokens,
|
||||||
|
Loading…
Reference in New Issue
Block a user