Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
fix tests

parent a62f14872e
commit caa9608347
@@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
     return BloomCausalLMBatch.from_pb(
-        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )


@@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(
-        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
+        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )

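Every fixture hunk in this commit, here and in the causal LM and seq2seq hunks below, makes the same mechanical change: from_pb now takes a dtype argument between the tokenizer and the device. A minimal sketch of the signature the new call sites imply; the class name and body are placeholders of mine, only the parameter order comes from the diff:

    from typing import Any

    import torch

    class BatchSketch:
        """Placeholder mirroring the call sites above; not a TGI class."""

        @classmethod
        def from_pb(
            cls,
            pb: Any,             # a generate_pb2.Batch protobuf message
            tokenizer: Any,      # a transformers tokenizer
            dtype: torch.dtype,  # new in this commit, e.g. torch.float32 on CPU
            device: torch.device,
        ) -> "BatchSketch":
            raise NotImplementedError("real batch classes tokenize and pad here")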
@@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):

 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
-    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )


 @pytest.fixture
@@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
-    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )


 def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
@@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
 @pytest.mark.skip
 def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
     batch = CausalLMBatch.from_pb(
-        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch

@@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
     default_santacoder, default_fim_pb_batch
 ):
     batch = CausalLMBatch.from_pb(
-        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_fim_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch

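Both santacoder tests now read the dtype and device off the model fixture instead of hard-coding them. The call sites only dereference three attributes of default_santacoder; a stub of mine showing the shape of object those tests assume (the real fixture constructs an actual model):

    from dataclasses import dataclass, field
    from typing import Any

    import torch

    @dataclass
    class SantacoderStub:
        # Only the attributes the tests above read; the class itself is my
        # illustration, not TGI's actual model class.
        tokenizer: Any = None
        dtype: torch.dtype = torch.float32
        device: torch.device = field(default_factory=lambda: torch.device("cpu"))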
@@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
     return Seq2SeqLMBatch.from_pb(
-        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+        default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
     )

@@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
-    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
+    return Seq2SeqLMBatch.from_pb(
+        batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
+    )


 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
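All of the CPU fixtures pin torch.float32 explicitly. The commit doesn't say why, but a plausible reason (my inference) is that half-precision kernels are spotty on CPU, so a float16 default would break these tests. For example, on many PyTorch CPU builds:

    import torch

    half = torch.randn(2, 2, dtype=torch.float16)
    try:
        half @ half  # may raise: matmul not implemented for 'Half' on CPU
    except RuntimeError as err:
        print(f"float16 matmul unavailable on this CPU build: {err}")
    print(torch.randn(2, 2) @ torch.randn(2, 2))  # float32 works everywhere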
@@ -393,7 +393,7 @@ class FlashCausalLMBatch(Batch):

         all_input_ids_tensor[
             start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-        ] = batch.all_input_ids_tensor
+        ] = batch.all_input_ids_tensor[:, :max_length]

         cumulative_batch_size += len(batch)
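The last hunk is the actual fix the tests exercise: when FlashCausalLMBatch copies each batch's all_input_ids_tensor into the merged tensor, a source batch can be padded wider than the destination's max_length, and the destination slice then clamps to a narrower width than the unclipped right-hand side, so the assignment fails. A standalone reproduction of mine (not TGI code) of that shape mismatch and the fix:

    import torch

    max_length = 4                          # width of the merged tensor
    dest = torch.zeros(2, max_length, dtype=torch.long)
    src = torch.arange(10).reshape(2, 5)    # one batch, padded wider (width 5)

    # Pre-fix:  dest[0:2, : src.shape[1]] = src
    #   -> RuntimeError: the slice clamps to width 4 but src has width 5.
    # Post-fix: clip the source too, so both sides are (2, 4).
    dest[0:2, : src.shape[1]] = src[:, :max_length]
    print(dest)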