diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 0f9dab2c..e467d291 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -1,3 +1,5 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
 import pytest
 import torch
 
@@ -5,7 +7,13 @@ from copy import copy
 from transformers import AutoTokenizer
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.causal_lm import (
+    CausalLM,
+    CausalLMBatch,
+    PREFILL_BATCH_BUCKET_SIZE,
+    PAD_SEQUENCE_TO_MULTIPLE_OF
+)
+
 
 
 @pytest.fixture(scope="session")
@@ -26,7 +34,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
         id=0,
         inputs="Test",
         prefill_logprobs=True,
-        truncate=100,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -40,7 +48,7 @@ def default_pb_batch(default_pb_request):
 @pytest.fixture
 def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
     return CausalLMBatch.from_pb(
-        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("hpu")
     )
 
 
@@ -54,7 +62,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
 
     batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(
-        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("hpu")
     )
 
 
@@ -62,30 +70,34 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     batch = default_causal_lm_batch
 
     assert batch.batch_id == default_pb_batch.id
-    assert batch.requests == default_pb_batch.requests
+    assert len(batch.requests) == len(default_pb_batch.requests)
 
-    assert len(batch.input_ids) == default_pb_batch.size
-    assert batch.input_ids[0][-1] == 14402
-    assert torch.all(batch.input_ids[0][:-1] == 50256)
+    for r in range(len(default_pb_batch.requests)):
+        assert batch.requests[r].data == default_pb_batch.requests[r]
 
-    assert batch.attention_mask[0, 0] == 1
-    assert torch.all(batch.attention_mask[0, 1:] == 0)
+    # For Gaudi, the batch is padded up to a multiple of the prefill bucket size
+    size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE
+
+    assert len(batch.input_ids) == size_of_padded_to_bucket
+
+    assert batch.input_ids[0][-2] == 14402
+    assert torch.all(batch.input_ids[0][:-2] == 50256)
+    assert batch.input_ids[0][-1] == 50256
+
+    assert batch.attention_mask[0, -2] == 1
+    assert torch.all(batch.attention_mask[0, :-2] == 0)
 
     assert batch.past_key_values is None
-
-    assert all(
+    assert all(
         [
-            torch.equal(input_ids, all_input_ids[:, 0])
-            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+            torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+            for input_ids, request in zip(batch.input_ids, batch.requests)
         ]
     )
 
-    assert batch.input_lengths == [1]
-
     assert len(batch) == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
 
-    assert batch.max_input_length == batch.input_lengths[0]
+    assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate
 
 
 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
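
Note (not part of the patch): the updated `test_batch_from_pb` assertions rely on the batch being rounded up to a multiple of `PREFILL_BATCH_BUCKET_SIZE` before prefill. The sketch below restates that rounding arithmetic standalone so the expected values are easy to check by hand; the `round_up_to_bucket` helper and the bucket size of 4 are hypothetical illustration values, while the real constant is imported from `text_generation_server.models.causal_lm` in the test.

```python
def round_up_to_bucket(batch_size: int, bucket_size: int) -> int:
    # Ceiling division, then scale back up: the same expression the test uses
    # to compute size_of_padded_to_bucket.
    return ((batch_size + bucket_size - 1) // bucket_size) * bucket_size


# Assuming a hypothetical bucket size of 4:
assert round_up_to_bucket(1, 4) == 4  # a single-request batch is padded to one full bucket
assert round_up_to_bucket(5, 4) == 8  # five requests round up to the next bucket boundary
assert round_up_to_bucket(8, 4) == 8  # exact multiples are left unchanged
```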