Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-25 12:02:08 +00:00
Adjust max_new_tokens in warmup (#160)

parent 1033d3b503
commit 4fe871ffaa
@@ -161,6 +161,7 @@ impl Client {
             max_total_tokens,
             seq_bucket_size,
             false,
+            None,
         )
     ];
     // if possible, create second batch in order to trigger concatenate operation
@@ -173,6 +174,7 @@ impl Client {
             max_total_tokens,
             seq_bucket_size,
             false,
+            None,
         )
     );
 }
@@ -188,6 +190,13 @@ impl Client {

     // send batches to warmup all possible decode shapes
     if decode_batch_sizes.len() > 1 {
+        let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
+            decode_bucket_size
+        } else {
+            decode_bucket_size.div_ceil(max_prefill_batch_size)
+        };
+        let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;
+
         let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
         let mut batches: Vec<Batch> = vec![
             self.create_warmup_batch(
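To see what budget the hunk above computes, here is a minimal, self-contained Rust sketch. The function name warmup_max_new_tokens and the example values in main are hypothetical stand-ins; only the steps_per_bucket branch and the factor of 2 are taken from the diff.

// Sketch of the warmup token budget added above. Inputs are illustrative;
// in the real code they come from the router's bucketing configuration.
fn warmup_max_new_tokens(
    decode_bucket_size: u32,
    max_prefill_batch_size: u32,
    num_decode_buckets: u32, // stands in for decode_batch_sizes.len()
) -> u32 {
    // Mirrors the branch in the diff: if a decode bucket fits within one
    // prefill batch, use the bucket size itself; otherwise use the number
    // of prefill rounds needed to fill the bucket, rounded up.
    let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
        decode_bucket_size
    } else {
        decode_bucket_size.div_ceil(max_prefill_batch_size)
    };
    // The factor of 2 is taken verbatim from the diff, presumably as
    // headroom so every decode bucket is reached before requests finish.
    2 * num_decode_buckets * steps_per_bucket
}

fn main() {
    // 4 decode buckets, bucket size 8, at most 2 requests per prefill:
    // steps_per_bucket = ceil(8 / 2) = 4, budget = 2 * 4 * 4 = 32.
    assert_eq!(warmup_max_new_tokens(8, 2, 4), 32);
}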
@@ -197,6 +206,7 @@ impl Client {
                 max_total_tokens,
                 seq_bucket_size,
                 false,
+                Some(max_new_tokens),
             )
         ];

@@ -220,6 +230,7 @@ impl Client {
                 max_total_tokens,
                 seq_bucket_size,
                 false,
+                Some(max_new_tokens),
             )
         );

@@ -250,6 +261,7 @@ impl Client {
             max_total_tokens,
             seq_bucket_size,
             true,
+            None,
         )
     ];
     let request = tonic::Request::new(WarmupRequest {
@@ -272,6 +284,7 @@ impl Client {
         max_total_tokens: u32,
         seq_bucket_size: u32,
         default_params: bool,
+        max_new_tokens: Option<u32>,
     ) -> Batch {
         *id_counter += 1;
         let (batch_size, input_length) = shape;
@@ -312,7 +325,7 @@ impl Client {
             truncate: max_input_length,
             parameters: req_params,
             stopping_parameters: Some(StoppingCriteriaParameters {
-                max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
+                max_new_tokens: cmp::min(max_new_tokens.unwrap_or(10), max_total_tokens - max_input_length),
                 stop_sequences: vec![],
                 ignore_eos_token: true,
             }),
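The net effect of the final hunk: warmup batches that receive a budget use it, all other call sites keep the old default of 10 new tokens, and both stay capped by the room left under max_total_tokens. A small sketch, with a hypothetical helper name and illustrative numbers:

use std::cmp;

// Hypothetical extraction of the changed expression; the argument values
// below are illustrative, not from a real server configuration.
fn capped_max_new_tokens(
    max_new_tokens: Option<u32>,
    max_total_tokens: u32,
    max_input_length: u32,
) -> u32 {
    cmp::min(max_new_tokens.unwrap_or(10), max_total_tokens - max_input_length)
}

fn main() {
    // Prefill warmup passes None and keeps the previous hard-coded 10.
    assert_eq!(capped_max_new_tokens(None, 2048, 1024), 10);
    // Decode warmup passes Some(budget), still capped by remaining room.
    assert_eq!(capped_max_new_tokens(Some(32), 2048, 2040), 8);
}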