fix tests

This commit is contained in:
OlivierDehaene 2024-06-12 18:54:25 +02:00
parent 05eb4dcb17
commit abe521204e
3 changed files with 12 additions and 12 deletions

View File

@ -199,7 +199,7 @@ def test_causal_lm_generate_token_completion_multi(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_bloom, default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
[], [],
) )
@ -312,8 +312,8 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_bloom, default_bloom,
[ [
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
], ],
[], [],
) )
@ -341,7 +341,7 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_bloom, default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])],
[], [],
) )

View File

@ -200,7 +200,7 @@ def test_causal_lm_generate_token_completion_multi(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_causal_lm, default_causal_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
[], [],
) )
@ -312,8 +312,8 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_causal_lm, default_causal_lm,
[ [
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
], ],
[], [],
) )
@ -340,7 +340,7 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_causal_lm, default_causal_lm,
[ [
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
], ],
[], [],
) )

View File

@ -208,7 +208,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_seq2seq_lm, default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])],
[], [],
) )
@ -346,8 +346,8 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_seq2seq_lm, default_seq2seq_lm,
[ [
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]),
], ],
[], [],
) )
@ -362,7 +362,7 @@ def test_batch_concatenate(
next_batch, _ = next_batch.filter( next_batch, _ = next_batch.filter(
default_seq2seq_lm, default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])],
[], [],
) )