mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00
add validation
This commit is contained in:
parent
c5a0b65c47
commit
1990d8633c
@ -14,9 +14,7 @@ def test_generate(flan_t5_xxl_url, hf_headers):
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
assert len(response.details.prefill) == 1
|
||||
assert response.details.prefill[0] == PrefillToken(
|
||||
id=0, text="<pad>", logprob=None
|
||||
)
|
||||
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
||||
assert len(response.details.tokens) == 1
|
||||
assert response.details.tokens[0] == Token(
|
||||
id=3, text=" ", logprob=-1.984375, special=False
|
||||
@ -72,9 +70,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
assert len(response.details.prefill) == 1
|
||||
assert response.details.prefill[0] == PrefillToken(
|
||||
id=0, text="<pad>", logprob=None
|
||||
)
|
||||
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
||||
assert len(response.details.tokens) == 1
|
||||
assert response.details.tokens[0] == Token(
|
||||
id=3, text=" ", logprob=-1.984375, special=False
|
||||
|
@ -11,6 +11,9 @@ def test_parameters_validation():
|
||||
Parameters(best_of=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(best_of=-1)
|
||||
Parameters(best_of=2, do_sample=True)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(best_of=2)
|
||||
|
||||
# Test repetition_penalty
|
||||
Parameters(repetition_penalty=1)
|
||||
@ -71,7 +74,9 @@ def test_request_validation():
|
||||
Request(inputs="")
|
||||
|
||||
Request(inputs="test", stream=True)
|
||||
Request(inputs="test", parameters=Parameters(best_of=2))
|
||||
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Request(inputs="test", parameters=Parameters(best_of=2), stream=True)
|
||||
Request(
|
||||
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
|
||||
)
|
||||
|
@ -6,8 +6,6 @@ from text_generation.errors import ValidationError
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
# Generate best_of sequences and return the one with the highest token logprobs
|
||||
best_of: Optional[int]
|
||||
# Activate logits sampling
|
||||
do_sample: bool = False
|
||||
# Maximum number of generated tokens
|
||||
@ -33,16 +31,29 @@ class Parameters(BaseModel):
|
||||
# Typical Decoding mass
|
||||
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||
typical_p: Optional[float]
|
||||
# Generate best_of sequences and return the one with the highest token logprobs
|
||||
best_of: Optional[int]
|
||||
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
watermark: bool = False
|
||||
# Get generation details
|
||||
details: bool = False
|
||||
|
||||
@validator("best_of")
|
||||
def valid_best_of(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`best_of` must be strictly positive")
|
||||
return v
|
||||
def valid_best_of(cls, field_value, values):
|
||||
if field_value is not None:
|
||||
if field_value <= 0:
|
||||
raise ValidationError("`best_of` must be strictly positive")
|
||||
sampling = (
|
||||
values["do_sample"]
|
||||
| (values["temperature"] is not None)
|
||||
| (values["top_k"] is not None)
|
||||
| (values["top_p"] is not None)
|
||||
| (values["typical_p"] is not None)
|
||||
)
|
||||
if field_value > 1 and not sampling:
|
||||
raise ValidationError("you must use sampling when `best_of` is > 1")
|
||||
|
||||
return field_value
|
||||
|
||||
@validator("repetition_penalty")
|
||||
def valid_repetition_penalty(cls, v):
|
||||
@ -105,10 +116,10 @@ class Request(BaseModel):
|
||||
def valid_best_of_stream(cls, field_value, values):
|
||||
parameters = values["parameters"]
|
||||
if (
|
||||
parameters is not None
|
||||
and parameters.best_of is not None
|
||||
and parameters.best_of > 1
|
||||
and field_value
|
||||
parameters is not None
|
||||
and parameters.best_of is not None
|
||||
and parameters.best_of > 1
|
||||
and field_value
|
||||
):
|
||||
raise ValidationError(
|
||||
"`best_of` != 1 is not supported when `stream` == True"
|
||||
|
Loading…
Reference in New Issue
Block a user