fix tests

OlivierDehaene 2023-05-24 17:02:20 +02:00
parent a62f14872e
commit caa9608347
5 changed files with 21 additions and 9 deletions

View File

@@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
     return BloomCausalLMBatch.from_pb(
-        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )
@@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(
-        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
+        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )
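The change above repeats mechanically across all four test files: from_pb now takes a dtype between the tokenizer and the device. A minimal sketch of what the updated signature presumably looks like; the class, field, and body below are hypothetical placeholders, and only the argument order (pb, tokenizer, dtype, device) comes from the call sites in this diff.

    import torch
    from dataclasses import dataclass

    @dataclass
    class SketchBatch:
        # Hypothetical stand-in for BloomCausalLMBatch and the other batch
        # classes touched here; only the from_pb argument order is from the diff.
        requests: list

        @classmethod
        def from_pb(cls, pb, tokenizer, dtype: torch.dtype, device: torch.device):
            # The CPU fixtures pass torch.float32 explicitly; the santacoder
            # tests further down pass the model's own dtype instead.
            return cls(requests=list(pb.requests))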

View File

@@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
-    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )


 @pytest.fixture
@@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5
     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
-    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )


 def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):

View File

@@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
 @pytest.mark.skip
 def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
     batch = CausalLMBatch.from_pb(
-        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch
@@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
     default_santacoder, default_fim_pb_batch
 ):
     batch = CausalLMBatch.from_pb(
-        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_fim_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch
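Note that these two tests, unlike the CPU-only fixtures above, read the dtype off the model (default_santacoder.dtype) rather than hardcoding torch.float32, so the batch matches whatever precision the model was loaded in. A hedged illustration of why that matters; this helper is not from the repo:

    import torch

    def likely_model_dtype(device: torch.device) -> torch.dtype:
        # Illustrative assumption: serving stacks commonly load models in half
        # precision on CUDA and full precision on CPU, so a hardcoded
        # torch.float32 would mismatch a float16 GPU model.
        return torch.float16 if device.type == "cuda" else torch.float32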

View File

@@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
     return Seq2SeqLMBatch.from_pb(
-        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+        default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
     )
@@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
-    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
+    return Seq2SeqLMBatch.from_pb(
+        batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
+    )


 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):

View File

@@ -393,7 +393,7 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor[
             start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-        ] = batch.all_input_ids_tensor
+        ] = batch.all_input_ids_tensor[:, :max_length]
         cumulative_batch_size += len(batch)
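This last hunk is the substantive fix: when batches are concatenated, the merged all_input_ids_tensor is allocated with max_length columns, but a source batch's tensor can be wider than that, and assigning it unsliced raises a shape mismatch. A self-contained toy reproduction; the tensor sizes are made up, while the variable names follow the diff:

    import torch

    max_length = 8  # column width of the merged buffer
    start_index, end_index = 0, 2

    # Destination buffer for the concatenated batch (4 sequences, 8 columns).
    all_input_ids_tensor = torch.zeros(4, max_length, dtype=torch.long)

    # A source batch whose tensor is wider than max_length.
    batch_all_input_ids_tensor = torch.ones(2, 10, dtype=torch.long)

    # Without the added slice, the right-hand side keeps 10 columns while the
    # left-hand side clamps to 8, and the assignment raises a RuntimeError.
    all_input_ids_tensor[
        start_index:end_index, : batch_all_input_ids_tensor.shape[1]
    ] = batch_all_input_ids_tensor[:, :max_length]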