wrong import

This commit is contained in:
OlivierDehaene 2023-03-07 13:57:32 +01:00
parent f3586b2308
commit b7c3c7dabb
3 changed files with 12 additions and 6 deletions

View File

@ -18,6 +18,12 @@ class Parameters(BaseModel):
watermark: bool = False watermark: bool = False
details: bool = False details: bool = False
@validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v is v < 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@validator("seed") @validator("seed")
def valid_seed(cls, v): def valid_seed(cls, v):
if v is not None and v is v < 0: if v is not None and v is v < 0:

View File

@ -1,6 +1,6 @@
from text_generation.utils.convert import convert_file, convert_files from text_generation_server.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import ( from text_generation_server.utils.hub import (
weight_files, weight_files,
weight_hub_files, weight_hub_files,
download_weights, download_weights,
@ -8,7 +8,7 @@ from text_generation.utils.hub import (
LocalEntryNotFoundError, LocalEntryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
) )
from text_generation.utils.tokens import ( from text_generation_server.utils.tokens import (
Greedy, Greedy,
NextTokenChooser, NextTokenChooser,
Sampling, Sampling,

View File

@ -12,8 +12,8 @@ from transformers import (
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.watermark import WatermarkLogitsProcessor
class Sampling: class Sampling: