mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 07:52:06 +00:00
Fix router tests (#119)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
parent
e210e15e27
commit
06227f7b5e
@ -453,6 +453,18 @@ mod tests {
|
|||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||||
use tracing::info_span;
|
use tracing::info_span;
|
||||||
|
|
||||||
|
fn default_queue() -> Queue {
|
||||||
|
Queue::new(
|
||||||
|
true, 1, 2, 1, None
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_state() -> State {
|
||||||
|
State::new(
|
||||||
|
true, 1, 2, 1, None
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn default_entry() -> (
|
fn default_entry() -> (
|
||||||
Entry,
|
Entry,
|
||||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
|
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
|
||||||
@ -493,7 +505,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_append() {
|
fn test_append() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = default_state();
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
assert_eq!(state.next_id, 0);
|
||||||
@ -509,7 +521,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_empty() {
|
fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = default_state();
|
||||||
|
|
||||||
assert!(state.next_batch(None, 1, 1).is_none());
|
assert!(state.next_batch(None, 1, 1).is_none());
|
||||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||||
@ -517,13 +529,13 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_min_size() {
|
fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = default_state();
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
|
let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -543,19 +555,19 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(state.next_id, 3);
|
assert_eq!(state.next_id, 3);
|
||||||
assert_eq!(state.entries.len(), 1);
|
assert_eq!(state.entries.len(), 1);
|
||||||
let (id, _) = state.entries.remove(0).unwrap();
|
let IdentifiableEntry(id, _) = state.entries.pop().unwrap();
|
||||||
assert_eq!(id, 2);
|
assert_eq!(id, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_token_budget() {
|
fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1, None);
|
let mut state = default_state();
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
|
let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -568,7 +580,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
state.append(entry3);
|
state.append(entry3);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
|
let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -582,14 +594,14 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = default_queue();
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = default_queue();
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||||
@ -597,13 +609,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = default_queue();
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -620,7 +632,7 @@ mod tests {
|
|||||||
// Not enough token budget
|
// Not enough token budget
|
||||||
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
|
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
|
||||||
// Ok
|
// Ok
|
||||||
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
|
let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap();
|
||||||
assert_eq!(entries2.len(), 1);
|
assert_eq!(entries2.len(), 1);
|
||||||
assert!(entries2.contains_key(&2));
|
assert!(entries2.contains_key(&2));
|
||||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||||
@ -630,13 +642,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = default_queue();
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
queue.append(entry2);
|
queue.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -645,7 +657,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
|
let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -655,7 +667,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, 1, None);
|
let queue = default_queue();
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ impl Validation {
|
|||||||
|
|
||||||
// Await on response channel
|
// Await on response channel
|
||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
let (inputs, input_length) = response_receiver.await.unwrap()?;
|
let (inputs, _) = response_receiver.await.unwrap()?;
|
||||||
|
|
||||||
let input_length = if self.skip_tokenizer_in_tgi {
|
let input_length = if self.skip_tokenizer_in_tgi {
|
||||||
inputs.chars().filter(|&c| c == ',').count() + 1
|
inputs.chars().filter(|&c| c == ',').count() + 1
|
||||||
@ -521,7 +521,7 @@ mod tests {
|
|||||||
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
|
Err(ValidationError::MaxTotalTokens(6, 5, 10)) => (),
|
||||||
_ => panic!("Unexpected not max new tokens"),
|
_ => panic!("Unexpected not max new tokens"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user