mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
Set unique request id during warmup (#170)
This commit is contained in:
parent
4b4382c6f8
commit
535a35db17
@ -150,13 +150,15 @@ impl Client {
|
||||
}
|
||||
}
|
||||
|
||||
let mut id_counter: u64 = 0;
|
||||
let mut batch_counter: u64 = 0;
|
||||
let mut request_counter: u64 = 0;
|
||||
for shape in shapes.iter() {
|
||||
let (batch_size, seq_length) = shape;
|
||||
let mut batches: Vec<Batch> = vec![
|
||||
self.create_warmup_batch(
|
||||
*shape,
|
||||
&mut id_counter,
|
||||
&mut batch_counter,
|
||||
&mut request_counter,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
seq_bucket_size,
|
||||
@ -169,7 +171,8 @@ impl Client {
|
||||
batches.push(
|
||||
self.create_warmup_batch(
|
||||
(1, *seq_length),
|
||||
&mut id_counter,
|
||||
&mut batch_counter,
|
||||
&mut request_counter,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
seq_bucket_size,
|
||||
@ -201,7 +204,8 @@ impl Client {
|
||||
let mut batches: Vec<Batch> = vec![
|
||||
self.create_warmup_batch(
|
||||
(requests_send, seq_bucket_size),
|
||||
&mut id_counter,
|
||||
&mut batch_counter,
|
||||
&mut request_counter,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
seq_bucket_size,
|
||||
@ -225,7 +229,8 @@ impl Client {
|
||||
batches.push(
|
||||
self.create_warmup_batch(
|
||||
(num_requests, seq_bucket_size),
|
||||
&mut id_counter,
|
||||
&mut batch_counter,
|
||||
&mut request_counter,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
seq_bucket_size,
|
||||
@ -256,7 +261,8 @@ impl Client {
|
||||
let batches: Vec<Batch> = vec![
|
||||
self.create_warmup_batch(
|
||||
*greedy_shape,
|
||||
&mut id_counter,
|
||||
&mut batch_counter,
|
||||
&mut request_counter,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
seq_bucket_size,
|
||||
@ -279,17 +285,19 @@ impl Client {
|
||||
fn create_warmup_batch(
|
||||
&mut self,
|
||||
shape: (u32, u32),
|
||||
id_counter: &mut u64,
|
||||
batch_counter: &mut u64,
|
||||
request_counter: &mut u64,
|
||||
max_input_length: u32,
|
||||
max_total_tokens: u32,
|
||||
seq_bucket_size: u32,
|
||||
default_params: bool,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Batch {
|
||||
*id_counter += 1;
|
||||
*batch_counter += 1;
|
||||
let (batch_size, input_length) = shape;
|
||||
let mut requests = Vec::new();
|
||||
for request_id in 0..batch_size {
|
||||
for _ in 0..batch_size {
|
||||
*request_counter += 1;
|
||||
let req_params = if default_params {
|
||||
Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
@ -320,7 +328,7 @@ impl Client {
|
||||
})
|
||||
};
|
||||
requests.push(Request {
|
||||
id: *id_counter + request_id as u64,
|
||||
id: *request_counter,
|
||||
inputs: self.get_random_input(input_length, seq_bucket_size),
|
||||
truncate: max_input_length,
|
||||
parameters: req_params,
|
||||
@ -335,7 +343,7 @@ impl Client {
|
||||
}
|
||||
|
||||
Batch {
|
||||
id: *id_counter,
|
||||
id: *batch_counter,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: max_total_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user