Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
fix tests
This commit is contained in:
parent 6012976445
commit 5bfc8631ce
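The test changes below all follow one pattern: `Batch.filter()` is now called with request ids rather than with the `Request` objects themselves. A minimal, self-contained sketch of that call pattern, using hypothetical `Request`/`Batch` stand-ins rather than the actual text-generation-inference classes:

from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    id: int
    inputs: str


@dataclass
class Batch:
    requests: List[Request]

    def filter(self, request_ids: List[int]) -> "Batch":
        # Keep only the requests whose id appears in request_ids.
        return Batch(requests=[r for r in self.requests if r.id in request_ids])


batch = Batch(requests=[Request(id=0, inputs="a"), Request(id=1, inputs="b")])

# Before this commit the tests passed Request objects:
#     next_batch = batch.filter([batch.requests[0]])
# After it, they pass request ids, matching the diff below:
next_batch = batch.filter([batch.requests[0].id])
assert [r.id for r in next_batch.requests] == [0]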
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -333,7 +333,7 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None
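Across the bloom, causal LM, and seq2seq LM test hunks above, the only change is the argument to filter(): surviving requests are identified by next_batch.requests[i].id instead of by the Request objects themselves, so the tests exercise the id-based filtering shown in the sketch at the top.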