mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
Fix router tests (#119)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
parent
e210e15e27
commit
06227f7b5e
@ -453,6 +453,18 @@ mod tests {
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_queue() -> Queue {
|
||||
Queue::new(
|
||||
true, 1, 2, 1, None
|
||||
)
|
||||
}
|
||||
|
||||
fn default_state() -> State {
|
||||
State::new(
|
||||
true, 1, 2, 1, None
|
||||
)
|
||||
}
|
||||
|
||||
fn default_entry() -> (
|
||||
Entry,
|
||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
|
||||
@ -493,7 +505,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new(false, 1, None);
|
||||
let mut state = default_state();
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
@ -509,7 +521,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None);
|
||||
let mut state = default_state();
|
||||
|
||||
assert!(state.next_batch(None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||
@ -517,13 +529,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None);
|
||||
let mut state = default_state();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, 2, 4).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -543,19 +555,19 @@ mod tests {
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
let IdentifiableEntry(id, _) = state.entries.pop().unwrap();
|
||||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None);
|
||||
let mut state = default_state();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, 1, 2).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -568,7 +580,7 @@ mod tests {
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, 3, 6).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
@ -582,14 +594,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, None);
|
||||
let queue = default_queue();
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, None);
|
||||
let queue = default_queue();
|
||||
|
||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||
@ -597,13 +609,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, None);
|
||||
let queue = default_queue();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -620,7 +632,7 @@ mod tests {
|
||||
// Not enough token budget
|
||||
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
|
||||
// Ok
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), 1, 2).await.unwrap();
|
||||
assert_eq!(entries2.len(), 1);
|
||||
assert!(entries2.contains_key(&2));
|
||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||
@ -630,13 +642,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, None);
|
||||
let queue = default_queue();
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, 1, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -645,7 +657,7 @@ mod tests {
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, 2, 4).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
@ -655,7 +667,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, None);
|
||||
let queue = default_queue();
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
@ -95,7 +95,7 @@ impl Validation {
|
||||
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
let (inputs, input_length) = response_receiver.await.unwrap()?;
|
||||
let (inputs, _) = response_receiver.await.unwrap()?;
|
||||
|
||||
let input_length = if self.skip_tokenizer_in_tgi {
|
||||
inputs.chars().filter(|&c| c == ',').count() + 1
|
||||
@ -521,7 +521,7 @@ mod tests {
|
||||
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
|
||||
Err(ValidationError::MaxTotalTokens(6, 5, 10)) => (),
|
||||
_ => panic!("Unexpected not max new tokens"),
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user