Fix router tests (#119)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
Karol Damaszke 2024-04-04 11:10:11 +02:00 committed by GitHub
parent e210e15e27
commit 06227f7b5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 19 deletions

View File

@ -453,6 +453,18 @@ mod tests {
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tracing::info_span; use tracing::info_span;
fn default_queue() -> Queue {
Queue::new(
true, 1, 2, 1, None
)
}
fn default_state() -> State {
State::new(
true, 1, 2, 1, None
)
}
fn default_entry() -> ( fn default_entry() -> (
Entry, Entry,
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>, mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
@ -493,7 +505,7 @@ mod tests {
#[test] #[test]
fn test_append() { fn test_append() {
let mut state = State::new(false, 1, None); let mut state = default_state();
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -509,7 +521,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_empty() { fn test_next_batch_empty() {
let mut state = State::new(false, 1, None); let mut state = default_state();
assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none());
@ -517,13 +529,13 @@ mod tests {
#[test] #[test]
fn test_next_batch_min_size() { fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None); let mut state = default_state();
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
state.append(entry2); state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap(); let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -543,19 +555,19 @@ mod tests {
assert_eq!(state.next_id, 3); assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1); assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0).unwrap(); let IdentifiableEntry(id, _) = state.entries.pop().unwrap();
assert_eq!(id, 2); assert_eq!(id, 2);
} }
#[test] #[test]
fn test_next_batch_token_budget() { fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None); let mut state = default_state();
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
state.append(entry2); state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap(); let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -568,7 +580,7 @@ mod tests {
let (entry3, _guard3) = default_entry(); let (entry3, _guard3) = default_entry();
state.append(entry3); state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap(); let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
@ -582,14 +594,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, None); let queue = default_queue();
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None); let queue = default_queue();
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@ -597,13 +609,13 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None); let queue = default_queue();
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -620,7 +632,7 @@ mod tests {
// Not enough token budget // Not enough token budget
assert!(queue.next_batch(Some(1), 0, 0).await.is_none()); assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
// Ok // Ok
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap(); let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap();
assert_eq!(entries2.len(), 1); assert_eq!(entries2.len(), 1);
assert!(entries2.contains_key(&2)); assert!(entries2.contains_key(&2));
assert!(entries2.get(&2).unwrap().batch_time.is_some()); assert!(entries2.get(&2).unwrap().batch_time.is_some());
@ -630,13 +642,13 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None); let queue = default_queue();
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -645,7 +657,7 @@ mod tests {
let (entry3, _guard3) = default_entry(); let (entry3, _guard3) = default_entry();
queue.append(entry3); queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
@ -655,7 +667,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None); let queue = default_queue();
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);

View File

@ -95,7 +95,7 @@ impl Validation {
// Await on response channel // Await on response channel
// Unwrap is safe here // Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?; let (inputs, _) = response_receiver.await.unwrap()?;
let input_length = if self.skip_tokenizer_in_tgi { let input_length = if self.skip_tokenizer_in_tgi {
inputs.chars().filter(|&c| c == ',').count() + 1 inputs.chars().filter(|&c| c == ',').count() + 1
@ -521,7 +521,7 @@ mod tests {
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(6, 5, 10)) => (),
_ => panic!("Unexpected not max new tokens"), _ => panic!("Unexpected not max new tokens"),
} }
} }