mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
add validation
This commit is contained in:
parent
c5a0b65c47
commit
1990d8633c
@ -14,9 +14,7 @@ def test_generate(flan_t5_xxl_url, hf_headers):
|
|||||||
assert response.details.generated_tokens == 1
|
assert response.details.generated_tokens == 1
|
||||||
assert response.details.seed is None
|
assert response.details.seed is None
|
||||||
assert len(response.details.prefill) == 1
|
assert len(response.details.prefill) == 1
|
||||||
assert response.details.prefill[0] == PrefillToken(
|
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
||||||
id=0, text="<pad>", logprob=None
|
|
||||||
)
|
|
||||||
assert len(response.details.tokens) == 1
|
assert len(response.details.tokens) == 1
|
||||||
assert response.details.tokens[0] == Token(
|
assert response.details.tokens[0] == Token(
|
||||||
id=3, text=" ", logprob=-1.984375, special=False
|
id=3, text=" ", logprob=-1.984375, special=False
|
||||||
@ -72,9 +70,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
|||||||
assert response.details.generated_tokens == 1
|
assert response.details.generated_tokens == 1
|
||||||
assert response.details.seed is None
|
assert response.details.seed is None
|
||||||
assert len(response.details.prefill) == 1
|
assert len(response.details.prefill) == 1
|
||||||
assert response.details.prefill[0] == PrefillToken(
|
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
||||||
id=0, text="<pad>", logprob=None
|
|
||||||
)
|
|
||||||
assert len(response.details.tokens) == 1
|
assert len(response.details.tokens) == 1
|
||||||
assert response.details.tokens[0] == Token(
|
assert response.details.tokens[0] == Token(
|
||||||
id=3, text=" ", logprob=-1.984375, special=False
|
id=3, text=" ", logprob=-1.984375, special=False
|
||||||
|
@ -11,6 +11,9 @@ def test_parameters_validation():
|
|||||||
Parameters(best_of=0)
|
Parameters(best_of=0)
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
Parameters(best_of=-1)
|
Parameters(best_of=-1)
|
||||||
|
Parameters(best_of=2, do_sample=True)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(best_of=2)
|
||||||
|
|
||||||
# Test repetition_penalty
|
# Test repetition_penalty
|
||||||
Parameters(repetition_penalty=1)
|
Parameters(repetition_penalty=1)
|
||||||
@ -71,7 +74,9 @@ def test_request_validation():
|
|||||||
Request(inputs="")
|
Request(inputs="")
|
||||||
|
|
||||||
Request(inputs="test", stream=True)
|
Request(inputs="test", stream=True)
|
||||||
Request(inputs="test", parameters=Parameters(best_of=2))
|
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
Request(inputs="test", parameters=Parameters(best_of=2), stream=True)
|
Request(
|
||||||
|
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
|
||||||
|
)
|
||||||
|
@ -6,8 +6,6 @@ from text_generation.errors import ValidationError
|
|||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
# Generate best_of sequences and return the one with the highest token logprobs
|
|
||||||
best_of: Optional[int]
|
|
||||||
# Activate logits sampling
|
# Activate logits sampling
|
||||||
do_sample: bool = False
|
do_sample: bool = False
|
||||||
# Maximum number of generated tokens
|
# Maximum number of generated tokens
|
||||||
@ -33,16 +31,29 @@ class Parameters(BaseModel):
|
|||||||
# Typical Decoding mass
|
# Typical Decoding mass
|
||||||
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
typical_p: Optional[float]
|
typical_p: Optional[float]
|
||||||
|
# Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
best_of: Optional[int]
|
||||||
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
watermark: bool = False
|
watermark: bool = False
|
||||||
# Get generation details
|
# Get generation details
|
||||||
details: bool = False
|
details: bool = False
|
||||||
|
|
||||||
@validator("best_of")
|
@validator("best_of")
|
||||||
def valid_best_of(cls, v):
|
def valid_best_of(cls, field_value, values):
|
||||||
if v is not None and v <= 0:
|
if field_value is not None:
|
||||||
raise ValidationError("`best_of` must be strictly positive")
|
if field_value <= 0:
|
||||||
return v
|
raise ValidationError("`best_of` must be strictly positive")
|
||||||
|
sampling = (
|
||||||
|
values["do_sample"]
|
||||||
|
| (values["temperature"] is not None)
|
||||||
|
| (values["top_k"] is not None)
|
||||||
|
| (values["top_p"] is not None)
|
||||||
|
| (values["typical_p"] is not None)
|
||||||
|
)
|
||||||
|
if field_value > 1 and not sampling:
|
||||||
|
raise ValidationError("you must use sampling when `best_of` is > 1")
|
||||||
|
|
||||||
|
return field_value
|
||||||
|
|
||||||
@validator("repetition_penalty")
|
@validator("repetition_penalty")
|
||||||
def valid_repetition_penalty(cls, v):
|
def valid_repetition_penalty(cls, v):
|
||||||
@ -105,10 +116,10 @@ class Request(BaseModel):
|
|||||||
def valid_best_of_stream(cls, field_value, values):
|
def valid_best_of_stream(cls, field_value, values):
|
||||||
parameters = values["parameters"]
|
parameters = values["parameters"]
|
||||||
if (
|
if (
|
||||||
parameters is not None
|
parameters is not None
|
||||||
and parameters.best_of is not None
|
and parameters.best_of is not None
|
||||||
and parameters.best_of > 1
|
and parameters.best_of > 1
|
||||||
and field_value
|
and field_value
|
||||||
):
|
):
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
"`best_of` != 1 is not supported when `stream` == True"
|
"`best_of` != 1 is not supported when `stream` == True"
|
||||||
|
Loading…
Reference in New Issue
Block a user