mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
@njhill, @yk FYI: generated_text was concatenated to the user prompt for legacy reasons. We want to remove this behaviour as we don't think it is useful and it is even detrimental to usability. We also remove the unused Vec.
94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
import pytest
|
|
|
|
from text_generation.pb import generate_pb2
|
|
from text_generation.models.causal_lm import CausalLMBatch
|
|
from text_generation.models.santacoder import SantaCoder
|
|
|
|
|
|
@pytest.fixture(scope="session")
def default_santacoder():
    """Session-scoped SantaCoder model fixture.

    Built once per test session because constructing the model (loading the
    ``bigcode/santacoder`` checkpoint) is expensive.
    """
    model = SantaCoder("bigcode/santacoder")
    return model
|
|
|
|
|
|
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    """Protobuf generate request for a plain (non-FIM) prompt.

    Uses the shared parameter fixtures and a single-token prompt ``"def"``.
    """
    request_fields = {
        "id": 0,
        "inputs": "def",
        "input_length": 1,
        "parameters": default_pb_parameters,
        "stopping_parameters": default_pb_stop_parameters,
    }
    return generate_pb2.Request(**request_fields)
|
|
|
|
|
|
@pytest.fixture
def default_pb_batch(default_pb_request):
    """Protobuf batch wrapping the single default (non-FIM) request."""
    requests = [default_pb_request]
    return generate_pb2.Batch(id=0, requests=requests, size=len(requests))
|
|
|
|
|
|
@pytest.fixture
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
    """Protobuf generate request using SantaCoder's fill-in-the-middle tokens.

    The prompt carries the ``<fim-prefix>``/``<fim-suffix>``/``<fim-middle>``
    markers so the model completes the middle of the snippet.
    """
    request_fields = {
        "id": 0,
        "inputs": "<fim-prefix>def<fim-suffix>world<fim-middle>",
        "input_length": 5,
        "parameters": default_pb_parameters,
        "stopping_parameters": default_pb_stop_parameters,
    }
    return generate_pb2.Request(**request_fields)
|
|
|
|
|
|
@pytest.fixture
def default_fim_pb_batch(default_fim_pb_request):
    """Protobuf batch wrapping the single fill-in-the-middle request."""
    requests = [default_fim_pb_request]
    return generate_pb2.Batch(id=0, requests=requests, size=len(requests))
|
|
|
|
|
|
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
    """Run SantaCoder to completion on a plain prompt and check the output.

    Generates tokens until the stopping criterion is reached, then verifies
    the final generated text, the request id, and the token count.
    """
    batch = CausalLMBatch.from_pb(
        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
    )
    next_batch = batch
    max_new_tokens = batch.stopping_criterias[0].max_new_tokens

    # Every step before the last one keeps the batch alive.
    for _ in range(max_new_tokens - 1):
        generations, next_batch = default_santacoder.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    # The final step exhausts the batch.
    generations, next_batch = default_santacoder.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
    generation = generations[0]
    assert generation.generated_text.text == " test_get_all_users_with_"
    assert generation.request_id == batch.requests[0].id
    assert generation.generated_text.generated_tokens == max_new_tokens
|
|
|
|
|
|
@pytest.mark.skip
def test_fim_santacoder_generate_token_completion(
    default_santacoder, default_fim_pb_batch
):
    """Run SantaCoder to completion on a fill-in-the-middle prompt.

    Mirrors ``test_santacoder_generate_token_completion`` but drives the
    model with a FIM-formatted request and checks the FIM completion text.
    """
    batch = CausalLMBatch.from_pb(
        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
    )
    next_batch = batch
    max_new_tokens = batch.stopping_criterias[0].max_new_tokens

    # Every step before the last one keeps the batch alive.
    for _ in range(max_new_tokens - 1):
        generations, next_batch = default_santacoder.generate_token(next_batch)
        assert len(generations) == len(next_batch)

    # The final step exhausts the batch.
    generations, next_batch = default_santacoder.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
    generation = generations[0]
    assert (
        generation.generated_text.text
        == """ineProperty(exports, "__esModule", { value"""
    )
    assert generation.request_id == batch.requests[0].id
    assert generation.generated_text.generated_tokens == max_new_tokens
|