diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 0b7e0d5f..0c85a406 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -70,7 +70,7 @@ impl Batcher {
 
         // Notify the background task that we have a new entry in the database that needs
         // to be batched
-        self.shared.batching_task.notify_waiters();
+        self.shared.batching_task.notify_one();
 
         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
@@ -104,10 +104,9 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        let mut waiting_tokens = 0;
         while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            waiting_tokens += 1;
+            let mut waiting_tokens = 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -131,11 +130,11 @@ async fn batching_task(
                     if let Some((new_request_ids, new_batch)) =
                         db.next_batch(min_size, max_batch_size)
                     {
-                        // Reset waiting counter
-                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Reset waiting counter
+                        waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
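
Note (reviewer context, not part of the diff): the change from `notify_waiters()` to `notify_one()` matters because of how `tokio::sync::Notify` handles wakeups. `notify_waiters` only wakes tasks already parked in `notified().await` and stores no permit, so a notification sent while the batching task is busy processing a previous batch is lost and new requests can sit in the database until something else wakes the task. `notify_one` stores a permit when no task is waiting, so the next `notified().await` completes immediately and the wakeup cannot be dropped. A minimal standalone sketch of that difference (illustrative only, not code from this PR):

use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Notify::new();

    // No task is parked in `notified().await` yet:
    // - notify_waiters() would be a no-op here (it stores no permit),
    // - notify_one() stores a permit for the next waiter.
    notify.notify_one();

    // Completes immediately by consuming the stored permit. Had we called
    // notify_waiters() above instead, this await would park forever.
    notify.notified().await;
    println!("woken by the stored permit");
}

Separately, the `waiting_tokens` hunks move the counter into the per-batch loop and make it count from 1 rather than 0: by the time the counter is set or reset, `client.generate` has already produced one token for the current batch composition, so 1 is the accurate number of generation steps since the batch last changed.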