add validation

This commit is contained in:
OlivierDehaene 2023-03-09 15:04:59 +01:00
parent c5a0b65c47
commit 1990d8633c
3 changed files with 30 additions and 18 deletions

View File

@ -14,9 +14,7 @@ def test_generate(flan_t5_xxl_url, hf_headers):
assert response.details.generated_tokens == 1 assert response.details.generated_tokens == 1
assert response.details.seed is None assert response.details.seed is None
assert len(response.details.prefill) == 1 assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken( assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
id=0, text="<pad>", logprob=None
)
assert len(response.details.tokens) == 1 assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token( assert response.details.tokens[0] == Token(
id=3, text=" ", logprob=-1.984375, special=False id=3, text=" ", logprob=-1.984375, special=False
@ -72,9 +70,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
assert response.details.generated_tokens == 1 assert response.details.generated_tokens == 1
assert response.details.seed is None assert response.details.seed is None
assert len(response.details.prefill) == 1 assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken( assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
id=0, text="<pad>", logprob=None
)
assert len(response.details.tokens) == 1 assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token( assert response.details.tokens[0] == Token(
id=3, text=" ", logprob=-1.984375, special=False id=3, text=" ", logprob=-1.984375, special=False

View File

@ -11,6 +11,9 @@ def test_parameters_validation():
Parameters(best_of=0) Parameters(best_of=0)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Parameters(best_of=-1) Parameters(best_of=-1)
Parameters(best_of=2, do_sample=True)
with pytest.raises(ValidationError):
Parameters(best_of=2)
# Test repetition_penalty # Test repetition_penalty
Parameters(repetition_penalty=1) Parameters(repetition_penalty=1)
@ -71,7 +74,9 @@ def test_request_validation():
Request(inputs="") Request(inputs="")
Request(inputs="test", stream=True) Request(inputs="test", stream=True)
Request(inputs="test", parameters=Parameters(best_of=2)) Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Request(inputs="test", parameters=Parameters(best_of=2), stream=True) Request(
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
)

View File

@ -6,8 +6,6 @@ from text_generation.errors import ValidationError
class Parameters(BaseModel): class Parameters(BaseModel):
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
# Activate logits sampling # Activate logits sampling
do_sample: bool = False do_sample: bool = False
# Maximum number of generated tokens # Maximum number of generated tokens
@ -33,16 +31,29 @@ class Parameters(BaseModel):
# Typical Decoding mass # Typical Decoding mass
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
typical_p: Optional[float] typical_p: Optional[float]
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool = False watermark: bool = False
# Get generation details # Get generation details
details: bool = False details: bool = False
@validator("best_of") @validator("best_of")
def valid_best_of(cls, v): def valid_best_of(cls, field_value, values):
if v is not None and v <= 0: if field_value is not None:
if field_value <= 0:
raise ValidationError("`best_of` must be strictly positive") raise ValidationError("`best_of` must be strictly positive")
return v sampling = (
values["do_sample"]
| (values["temperature"] is not None)
| (values["top_k"] is not None)
| (values["top_p"] is not None)
| (values["typical_p"] is not None)
)
if field_value > 1 and not sampling:
raise ValidationError("you must use sampling when `best_of` is > 1")
return field_value
@validator("repetition_penalty") @validator("repetition_penalty")
def valid_repetition_penalty(cls, v): def valid_repetition_penalty(cls, v):