mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Enable server UT: test_causal_lm.py::test_batch_from_pb (#121)
Co-authored-by: Jacek Czaja <jczaja@habana.ai>
This commit is contained in:
parent
30cc78773e
commit
ae6215fcea
@ -1,3 +1,5 @@
|
||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -5,7 +7,13 @@ from copy import copy
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
|
||||
from text_generation_server.models.causal_lm import (
|
||||
CausalLM,
|
||||
CausalLMBatch,
|
||||
PREFILL_BATCH_BUCKET_SIZE,
|
||||
PAD_SEQUENCE_TO_MULTIPLE_OF
|
||||
)
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -26,7 +34,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
id=0,
|
||||
inputs="Test",
|
||||
prefill_logprobs=True,
|
||||
truncate=100,
|
||||
truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
@ -40,7 +48,7 @@ def default_pb_batch(default_pb_request):
|
||||
@pytest.fixture
|
||||
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
|
||||
return CausalLMBatch.from_pb(
|
||||
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
|
||||
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("hpu")
|
||||
)
|
||||
|
||||
|
||||
@ -54,7 +62,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
|
||||
return CausalLMBatch.from_pb(
|
||||
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
|
||||
batch_pb, gpt2_tokenizer, torch.float32, torch.device("hpu")
|
||||
)
|
||||
|
||||
|
||||
@ -62,30 +70,34 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
||||
batch = default_causal_lm_batch
|
||||
|
||||
assert batch.batch_id == default_pb_batch.id
|
||||
assert batch.requests == default_pb_batch.requests
|
||||
assert len(batch.requests) == len(default_pb_batch.requests)
|
||||
|
||||
assert len(batch.input_ids) == default_pb_batch.size
|
||||
assert batch.input_ids[0][-1] == 14402
|
||||
assert torch.all(batch.input_ids[0][:-1] == 50256)
|
||||
for r in range(0,len(default_pb_batch.requests)):
|
||||
assert batch.requests[r].data == default_pb_batch.requests[r]
|
||||
|
||||
assert batch.attention_mask[0, 0] == 1
|
||||
assert torch.all(batch.attention_mask[0, 1:] == 0)
|
||||
# For Gaudi we are adding padding of multiplication of bucket size
|
||||
size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE
|
||||
|
||||
assert len(batch.input_ids) == size_of_padded_to_bucket
|
||||
|
||||
assert batch.input_ids[0][-2] == 14402
|
||||
assert torch.all(batch.input_ids[0][:-2] == 50256)
|
||||
assert batch.input_ids[0][-1] == 50256
|
||||
|
||||
assert batch.attention_mask[0, -2] == 1
|
||||
assert torch.all(batch.attention_mask[0, :-2] == 0)
|
||||
|
||||
assert batch.past_key_values is None
|
||||
|
||||
assert all(
|
||||
assert all(
|
||||
[
|
||||
torch.equal(input_ids, all_input_ids[:, 0])
|
||||
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
|
||||
torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
|
||||
for input_ids, request in zip(batch.input_ids, batch.requests)
|
||||
]
|
||||
)
|
||||
|
||||
assert batch.input_lengths == [1]
|
||||
|
||||
assert len(batch) == default_pb_batch.size
|
||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
|
||||
|
||||
assert batch.max_input_length == batch.input_lengths[0]
|
||||
assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate
|
||||
|
||||
|
||||
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
|
||||
|
Loading…
Reference in New Issue
Block a user