From 06227f7b5e9d936fc3246f3d649b9e92f14d8559 Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Thu, 4 Apr 2024 11:10:11 +0200 Subject: [PATCH] Fix router tests (#119) Co-authored-by: Karol Damaszke --- router/src/queue.rs | 46 +++++++++++++++++++++++++--------------- router/src/validation.rs | 4 ++-- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/router/src/queue.rs b/router/src/queue.rs index 2b5f61b1..8590d82c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -453,6 +453,18 @@ mod tests { use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use tracing::info_span; + fn default_queue() -> Queue { + Queue::new( + true, 1, 2, 1, None + ) + } + + fn default_state() -> State { + State::new( + true, 1, 2, 1, None + ) + } + fn default_entry() -> ( Entry, mpsc::UnboundedReceiver>, @@ -493,7 +505,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false, 1, None); + let mut state = default_state(); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -509,7 +521,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false, 1, None); + let mut state = default_state(); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -517,13 +529,13 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None); + let mut state = default_state(); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -543,19 +555,19 @@ mod tests { assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0).unwrap(); + let IdentifiableEntry(id, _) = state.entries.pop().unwrap(); assert_eq!(id, 2); } #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None); + let mut state = default_state(); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap(); + let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -568,7 +580,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap(); + let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -582,14 +594,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None); + let queue = default_queue(); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None); + let queue = default_queue(); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -597,13 +609,13 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None); + let queue = default_queue(); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -620,7 +632,7 @@ mod tests { // Not enough token budget assert!(queue.next_batch(Some(1), 0, 0).await.is_none()); // Ok - let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap(); + let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap(); assert_eq!(entries2.len(), 1); assert!(entries2.contains_key(&2)); assert!(entries2.get(&2).unwrap().batch_time.is_some()); @@ -630,13 +642,13 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None); + let queue = default_queue(); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -645,7 +657,7 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); - let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -655,7 +667,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None); + let queue = default_queue(); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/validation.rs b/router/src/validation.rs index aeaf463a..9afdb0bb 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -95,7 +95,7 @@ impl Validation { // Await on response channel // Unwrap is safe here - let (inputs, input_length) = response_receiver.await.unwrap()?; + let (inputs, _) = response_receiver.await.unwrap()?; let input_length = if self.skip_tokenizer_in_tgi { inputs.chars().filter(|&c| c == ',').count() + 1 @@ -521,7 +521,7 @@ mod tests { .validate_input("Hello".to_string(), None, Some(max_new_tokens)) .await { - Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), + Err(ValidationError::MaxTotalTokens(6, 5, 10)) => (), _ => panic!("Unexpected not max new tokens"), } }