From 5bfc8631ce1ccc4f877ae7411e675eb87864ee09 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 24 May 2023 18:56:26 +0200
Subject: [PATCH] fix tests

---
 server/tests/models/test_bloom.py      | 6 +++---
 server/tests/models/test_causal_lm.py  | 6 +++---
 server/tests/models/test_seq2seq_lm.py | 6 +++---
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index f0adab97..105b3573 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index f1f13e4b..d8d1bd16 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index ba769e75..8fdeee60 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -333,7 +333,7 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
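
Note: the test updates above all follow the same pattern: batch.filter() is now called with request ids instead of Request objects. The sketch below is a minimal, hypothetical illustration of that signature change, not the server's actual Batch implementation; the Request and Batch classes here are stand-ins for the real ones.

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Request:
        id: int
        inputs: str


    @dataclass
    class Batch:
        requests: List[Request]

        def filter(self, request_ids: List[int]) -> "Batch":
            # Keep only the requests whose id appears in request_ids,
            # preserving their original order in the batch.
            keep = set(request_ids)
            return Batch(requests=[r for r in self.requests if r.id in keep])


    # Usage mirroring the updated tests: pass ids, not Request objects.
    batch = Batch(requests=[Request(id=0, inputs="a"), Request(id=1, inputs="b")])
    next_batch = batch.filter([batch.requests[0].id])
    assert [r.id for r in next_batch.requests] == [0]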