Add bucket for input seq len exactly same as --max-input-length (#178)

This commit is contained in:
Sun Choi 2024-07-05 01:30:26 -07:00 committed by GitHub
parent 1b4d80c03e
commit fff1d4f86f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -137,8 +137,8 @@ impl Client {
let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128); let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect(); let mut seq_lengths: Vec<u32> = (seq_bucket_size..max_input_length+1).step_by(seq_bucket_size as usize).collect();
if let Some(&last) = seq_lengths.last() { if let Some(&last) = seq_lengths.last() {
if last < max_input_length { if last < (max_input_length + 1) {
seq_lengths.push(max_input_length); seq_lengths.push(max_input_length + 1);
} }
} }