Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 16:02:10 +00:00)
Enable server UT: test_causal_lm.py::test_batch_from_pb (#121)
Co-authored-by: Jacek Czaja <jczaja@habana.ai>
parent 30cc78773e
commit ae6215fcea
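This change re-enables the test_batch_from_pb unit test on the Gaudi (HPU) backend. As a quick check, the single test can be driven through pytest's Python API; the sketch below assumes the upstream file location server/tests/models/test_causal_lm.py, which is not shown on this page.

import pytest

# Minimal sketch: run only the re-enabled test. The file path is an assumption
# based on the upstream text-generation-inference layout.
if __name__ == "__main__":
    raise SystemExit(
        pytest.main(["server/tests/models/test_causal_lm.py::test_batch_from_pb", "-v"])
    )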
@@ -1,3 +1,5 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
 import pytest
 import torch
 
@@ -5,7 +7,13 @@ from copy import copy
 from transformers import AutoTokenizer
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.causal_lm import (
+    CausalLM,
+    CausalLMBatch,
+    PREFILL_BATCH_BUCKET_SIZE,
+    PAD_SEQUENCE_TO_MULTIPLE_OF
+)
+
 
 
 @pytest.fixture(scope="session")
@@ -26,7 +34,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
         id=0,
         inputs="Test",
         prefill_logprobs=True,
-        truncate=100,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -40,7 +48,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
     return CausalLMBatch.from_pb(
-        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("hpu")
     )
 
 
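Both batch fixtures now target the Gaudi device instead of the CPU. A minimal sketch of how "hpu" becomes usable from PyTorch; the helper name is hypothetical, and it assumes the Habana PyTorch bridge is installed, falling back to CPU otherwise.

import torch

def pick_device() -> torch.device:
    # Hypothetical helper: prefer Gaudi when the Habana bridge has registered
    # the "hpu" backend, otherwise fall back to CPU.
    try:
        import habana_frameworks.torch.core  # noqa: F401
        return torch.device("hpu")
    except ImportError:
        return torch.device("cpu")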
@@ -54,7 +62,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
 
     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(
-        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("hpu")
     )
 
 
@@ -62,30 +70,34 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     batch = default_causal_lm_batch
 
     assert batch.batch_id == default_pb_batch.id
-    assert batch.requests == default_pb_batch.requests
+    assert len(batch.requests) == len(default_pb_batch.requests)
 
-    assert len(batch.input_ids) == default_pb_batch.size
-    assert batch.input_ids[0][-1] == 14402
-    assert torch.all(batch.input_ids[0][:-1] == 50256)
+    for r in range(0,len(default_pb_batch.requests)):
+        assert batch.requests[r].data == default_pb_batch.requests[r]
 
-    assert batch.attention_mask[0, 0] == 1
-    assert torch.all(batch.attention_mask[0, 1:] == 0)
+    # For Gaudi we are adding padding of multiplication of bucket size
+    size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE
 
+    assert len(batch.input_ids) == size_of_padded_to_bucket
+
+    assert batch.input_ids[0][-2] == 14402
+    assert torch.all(batch.input_ids[0][:-2] == 50256)
+    assert batch.input_ids[0][-1] == 50256
+
+    assert batch.attention_mask[0, -2] == 1
+    assert torch.all(batch.attention_mask[0, :-2] == 0)
+
     assert batch.past_key_values is None
-
     assert all(
         [
-            torch.equal(input_ids, all_input_ids[:, 0])
-            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+            torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+            for input_ids, request in zip(batch.input_ids, batch.requests)
         ]
     )
 
-    assert batch.input_lengths == [1]
-
     assert len(batch) == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
 
-    assert batch.max_input_length == batch.input_lengths[0]
+    assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate
 
 
 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
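The rewritten assertions reflect two Gaudi-specific padding steps: the batch is rounded up to a multiple of PREFILL_BATCH_BUCKET_SIZE, and each sequence is padded out to the truncate value (now PAD_SEQUENCE_TO_MULTIPLE_OF), which is why the single real token "Test" (GPT-2 id 14402) ends up at position -2 with pad tokens (id 50256) elsewhere. A standalone sketch of the rounding arithmetic, using assumed constant values (the real ones come from text_generation_server.models.causal_lm):

# Assumed values for illustration only; the actual constants are defined in
# text_generation_server.models.causal_lm.
PREFILL_BATCH_BUCKET_SIZE = 4
PAD_SEQUENCE_TO_MULTIPLE_OF = 128

def round_up(value: int, multiple: int) -> int:
    # Same ceiling-division trick used for size_of_padded_to_bucket above.
    return ((value + multiple - 1) // multiple) * multiple

# A protobuf batch holding a single request is padded to a full bucket,
# so len(batch.input_ids) becomes 4 rather than 1 under these assumptions.
assert round_up(1, PREFILL_BATCH_BUCKET_SIZE) == 4

# With truncate=PAD_SEQUENCE_TO_MULTIPLE_OF, the one real token sits inside a
# 128-wide row of GPT-2 pad tokens, hence the index -2 / -1 checks in the test.
assert round_up(100, PAD_SEQUENCE_TO_MULTIPLE_OF) == 128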