Adjust max_new_tokens in warmup (#160)

Karol Damaszke 2024-06-20 19:48:37 +02:00 committed by GitHub
parent 1033d3b503
commit 4fe871ffaa

@@ -161,6 +161,7 @@ impl Client {
                 max_total_tokens,
                 seq_bucket_size,
                 false,
+                None,
             )
         ];
         // if possible, create second batch in order to trigger concatenate operation
@@ -173,6 +174,7 @@ impl Client {
                     max_total_tokens,
                     seq_bucket_size,
                     false,
+                    None,
                 )
             );
         }
@@ -188,6 +190,13 @@ impl Client {
         // send batches to warmup all possible decode shapes
         if decode_batch_sizes.len() > 1 {
+            let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
+                decode_bucket_size
+            } else {
+                decode_bucket_size.div_ceil(max_prefill_batch_size)
+            };
+            let max_new_tokens: u32 = 2 * decode_batch_sizes.len() as u32 * steps_per_bucket;
             let mut requests_send: u32 = cmp::min(max_prefill_batch_size, decode_bucket_size);
             let mut batches: Vec<Batch> = vec![
                 self.create_warmup_batch(
@@ -197,6 +206,7 @@ impl Client {
                     max_total_tokens,
                     seq_bucket_size,
                     false,
+                    Some(max_new_tokens),
                 )
             ];
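
Note on the sizing logic introduced above: each prefill step contributes at most max_prefill_batch_size requests, so filling one decode bucket can take several steps, and the warmup has to keep decoding long enough to pass through every decode bucket shape. A minimal standalone sketch of the same arithmetic (the helper name warmup_max_new_tokens and the sample values are illustrative, not part of the commit):

// Mirror of the computation added in the hunk above; names and values
// here are illustrative only.
fn warmup_max_new_tokens(
    decode_bucket_size: u32,
    max_prefill_batch_size: u32,
    num_decode_buckets: u32,
) -> u32 {
    let steps_per_bucket: u32 = if decode_bucket_size <= max_prefill_batch_size {
        decode_bucket_size
    } else {
        // A full decode bucket needs several prefill steps to fill up.
        decode_bucket_size.div_ceil(max_prefill_batch_size)
    };
    // The factor of 2 leaves headroom so every decode bucket shape is reached.
    2 * num_decode_buckets * steps_per_bucket
}

fn main() {
    // Bucket size 8, prefill batch size 4, 3 decode buckets:
    // steps_per_bucket = ceil(8 / 4) = 2, so the budget is 2 * 3 * 2 = 12.
    assert_eq!(warmup_max_new_tokens(8, 4, 3), 12);
}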
@@ -220,6 +230,7 @@ impl Client {
                         max_total_tokens,
                         seq_bucket_size,
                         false,
+                        Some(max_new_tokens),
                     )
                 );
@@ -250,6 +261,7 @@ impl Client {
                 max_total_tokens,
                 seq_bucket_size,
                 true,
+                None,
             )
         ];
         let request = tonic::Request::new(WarmupRequest {
@@ -272,6 +284,7 @@ impl Client {
         max_total_tokens: u32,
         seq_bucket_size: u32,
         default_params: bool,
+        max_new_tokens: Option<u32>,
     ) -> Batch {
         *id_counter += 1;
         let (batch_size, input_length) = shape;
@@ -312,7 +325,7 @@ impl Client {
                 truncate: max_input_length,
                 parameters: req_params,
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
+                    max_new_tokens: cmp::min(max_new_tokens.unwrap_or(10), max_total_tokens - max_input_length),
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
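
For the final hunk: the optional budget threaded through create_warmup_batch falls back to the previous hard-coded default of 10 new tokens and is still clamped to the room left after the longest prompt. A self-contained sketch of that clamp (the helper name effective_max_new_tokens is illustrative, not from the commit):

use std::cmp;

fn effective_max_new_tokens(
    requested: Option<u32>,
    max_total_tokens: u32,
    max_input_length: u32,
) -> u32 {
    // None -> the old default of 10; Some(n) -> the decode-warmup budget;
    // either way, never exceed the tokens left after the prompt.
    cmp::min(requested.unwrap_or(10), max_total_tokens - max_input_length)
}

fn main() {
    assert_eq!(effective_max_new_tokens(None, 2048, 1024), 10);
    assert_eq!(effective_max_new_tokens(Some(64), 2048, 1024), 64);
    // A budget larger than the remaining context is clamped.
    assert_eq!(effective_max_new_tokens(Some(4096), 2048, 1024), 1024);
}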