fix cargo tests

This commit is contained in:
OlivierDehaene 2024-10-10 16:54:42 +02:00
parent f923a3fb68
commit df98299919
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -339,15 +339,23 @@ impl State {
let postfix_len = entry.request.input_length - block_allocation.prefix_len; let postfix_len = entry.request.input_length - block_allocation.prefix_len;
// Check equality too as if we don't we might end up with a postfix_len = 0 if prefill_tokens + postfix_len > prefill_token_budget {
// in the next iteration of the loop
if prefill_tokens + postfix_len >= prefill_token_budget {
// Entry is over budget // Entry is over budget
if self.support_chunking { if self.support_chunking {
// We support chunking, just set postfix_len to exactly match prefill_token_budget // We support chunking, just set postfix_len to exactly match prefill_token_budget
let chunk_len = prefill_token_budget - prefill_tokens; let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
// Push this entry inside the batch if chunk_len > 0 {
batch.push((id, entry, Some(block_allocation), Some(chunk_len))); // Push this entry inside the batch
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
} else {
// We cannot prefill even one token for this entry
// Add it back to the queue
self.entries.push_front((id, entry));
}
tracing::debug!(
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
prefill_tokens + postfix_len
);
break 'entry_loop; break 'entry_loop;
} else { } else {
// We don't support chunking, this entry needs to go back to the buffer // We don't support chunking, this entry needs to go back to the buffer
@ -658,7 +666,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_token_budget() { async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, false, None, 0, 2, false); let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -780,7 +788,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, false, None, 2, 16, false); let queue = Queue::new(true, 1, false, None, 2, 16, false);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);