Enable server UT: test_causal_lm.py::test_batch_from_pb (#121)

Co-authored-by: Jacek Czaja <jczaja@habana.ai>
Author: Jacek Czaja, 2024-04-10 16:33:56 +02:00 (committed by GitHub)
parent 30cc78773e
commit ae6215fcea

test_causal_lm.py

@@ -1,3 +1,5 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 import pytest
 import torch
@@ -5,7 +7,13 @@ from copy import copy
 from transformers import AutoTokenizer
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.causal_lm import (
+    CausalLM,
+    CausalLMBatch,
+    PREFILL_BATCH_BUCKET_SIZE,
+    PAD_SEQUENCE_TO_MULTIPLE_OF
+)
 @pytest.fixture(scope="session")
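The two newly imported names are Gaudi-specific bucketing constants defined in the Habana fork of causal_lm.py. As a rough sketch of what they look like there, assuming the usual pattern of environment-variable overrides with defaults of 4 and 128 (the default values are assumptions, not taken from this diff):

import os

# Prefill batches are rounded up to a multiple of this bucket size so the
# number of distinct Gaudi graph shapes stays bounded (default assumed: 4).
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 4))

# Input sequences are padded to a multiple of this length for the same
# reason (default assumed: 128).
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))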
@@ -26,7 +34,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
         id=0,
         inputs="Test",
         prefill_logprobs=True,
-        truncate=100,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -40,7 +48,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
     return CausalLMBatch.from_pb(
-        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("hpu")
     )
@@ -54,7 +62,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(
-        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("hpu")
     )
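Both fixtures now build the batch on torch.device("hpu"), so running the test requires a machine with the Habana PyTorch bridge installed. A minimal, hypothetical helper (not part of this commit) that falls back to CPU when the bridge is unavailable could look like this:

import torch

def pick_test_device() -> torch.device:
    # The "hpu" backend is only usable once the Habana bridge has been imported;
    # fall back to CPU so the fixtures can at least be constructed elsewhere.
    try:
        import habana_frameworks.torch.core  # noqa: F401
        return torch.device("hpu")
    except ImportError:
        return torch.device("cpu")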
@@ -62,30 +70,34 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     batch = default_causal_lm_batch
     assert batch.batch_id == default_pb_batch.id
-    assert batch.requests == default_pb_batch.requests
+    assert len(batch.requests) == len(default_pb_batch.requests)
-    assert len(batch.input_ids) == default_pb_batch.size
-    assert batch.input_ids[0][-1] == 14402
-    assert torch.all(batch.input_ids[0][:-1] == 50256)
+    for r in range(0,len(default_pb_batch.requests)):
+        assert batch.requests[r].data == default_pb_batch.requests[r]
-    assert batch.attention_mask[0, 0] == 1
-    assert torch.all(batch.attention_mask[0, 1:] == 0)
+    # For Gaudi we are adding padding of multiplication of bucket size
+    size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE
+    assert len(batch.input_ids) == size_of_padded_to_bucket
+    assert batch.input_ids[0][-2] == 14402
+    assert torch.all(batch.input_ids[0][:-2] == 50256)
+    assert batch.input_ids[0][-1] == 50256
+    assert batch.attention_mask[0, -2] == 1
+    assert torch.all(batch.attention_mask[0, :-2] == 0)
     assert batch.past_key_values is None
     assert all(
         [
-            torch.equal(input_ids, all_input_ids[:, 0])
-            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+            torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+            for input_ids, request in zip(batch.input_ids, batch.requests)
         ]
     )
-    assert batch.input_lengths == [1]
     assert len(batch) == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
-    assert batch.max_input_length == batch.input_lengths[0]
+    assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate

 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
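The bucket-padding assertions are easiest to read with concrete numbers. A worked example of the rounding, assuming the single-request default_pb_batch (size 1) and a PREFILL_BATCH_BUCKET_SIZE of 4 (both values are assumptions about the fixture and default, not stated in this diff):

batch_size = 1   # default_pb_batch.size, a single default request
bucket = 4       # assumed PREFILL_BATCH_BUCKET_SIZE
# Round the batch size up to the next multiple of the bucket size.
padded = ((batch_size + bucket - 1) // bucket) * bucket
assert padded == 4  # so len(batch.input_ids) is 4 padded rows, not 1

Within each row, the assertions indicate that the "Test" token (GPT-2 id 14402) now sits at index -2 rather than -1, with the remaining positions holding the GPT-2 pad/eos id 50256, since the Gaudi path reserves one extra padded slot at the end of the sequence.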