mirror of https://github.com/huggingface/text-generation-inference.git
fix tests

parent a62f14872e
commit caa9608347
@@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
     return BloomCausalLMBatch.from_pb(
-        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )
 
 
@@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
 
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(
-        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
+        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
     )
@@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):
 
 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
-    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )
 
 
 @pytest.fixture
@@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5
 
     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
-    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )
 
 
 def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
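
Across these test updates the pattern is the same: from_pb now takes a torch.dtype between the tokenizer and the device. A minimal sketch of the signature the new call sites imply; the parameter names and type hints are assumptions, not shown in this diff:

import torch
from transformers import PreTrainedTokenizerBase


class CausalLMBatch:
    @classmethod
    def from_pb(
        cls,
        pb,  # generate_pb2.Batch protobuf message
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        # Tokenize pb.requests, build the input tensors on `device`, and use
        # `dtype` for whatever floating-point buffers the batch allocates.
        raise NotImplementedError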
@@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
 @pytest.mark.skip
 def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
     batch = CausalLMBatch.from_pb(
-        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch
 
@@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
     default_santacoder, default_fim_pb_batch
 ):
     batch = CausalLMBatch.from_pb(
-        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_fim_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
     )
     next_batch = batch
 
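
The santacoder tests read the dtype off the model object instead of hard-coding one, so the fixture only needs to expose tokenizer, dtype, and device. A toy stand-in illustrating that contract; the class name and dtype selection below are assumptions, not the repository's fixture:

import torch
from transformers import AutoTokenizer


class FakeSantacoder:
    # Stand-in for the `default_santacoder` fixture: the from_pb calls above
    # only rely on these three attributes being present.
    def __init__(self, model_id: str = "bigcode/santacoder"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32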
@@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
     return Seq2SeqLMBatch.from_pb(
-        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+        default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
     )
 
 
@@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_1.stopping_parameters.max_new_tokens = 5
 
     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
-    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
+    return Seq2SeqLMBatch.from_pb(
+        batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
+    )
 
 
 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
@@ -393,7 +393,7 @@ class FlashCausalLMBatch(Batch):
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor
+            ] = batch.all_input_ids_tensor[:, :max_length]
 
             cumulative_batch_size += len(batch)
 
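
The FlashCausalLMBatch change guards the batch-concatenation copy: when an incoming batch carries an all_input_ids_tensor wider than the destination buffer (presumably allocated with width max_length), the destination view gets clamped by the slice while the source does not, and the assignment fails. A small standalone reproduction of that mismatch, with made-up shapes:

import torch

max_length = 4
# Destination buffer allocated for the concatenated batch.
all_input_ids_tensor = torch.zeros(3, max_length, dtype=torch.int64)

# Incoming batch tensor padded wider than max_length (here 6 columns).
batch_tensor = torch.arange(12, dtype=torch.int64).reshape(2, 6)
start_index, end_index = 0, 2

# Without slicing the source, the right-hand side has shape (2, 6) while the
# destination view is clamped to (2, 4), so the copy raises a size mismatch.
all_input_ids_tensor[
    start_index:end_index, : batch_tensor.shape[1]
] = batch_tensor[:, :max_length]

print(all_input_ids_tensor)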