diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index ef5b379a..86d0962e 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -18,6 +18,13 @@ class Parameters(BaseModel):
     watermark: bool = False
     details: bool = False
 
+    @validator("repetition_penalty")
+    def valid_repetition_penalty(cls, v):
+        """Reject a zero or negative `repetition_penalty`; `None` passes through."""
+        if v is not None and v <= 0:
+            raise ValidationError("`repetition_penalty` must be strictly positive")
+        return v
+
     @validator("seed")
     def valid_seed(cls, v):
         if v is not None and v is v < 0:
diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py
index a390b710..50d64518 100644
--- a/server/text_generation_server/utils/__init__.py
+++ b/server/text_generation_server/utils/__init__.py
@@ -1,6 +1,6 @@
-from text_generation.utils.convert import convert_file, convert_files
-from text_generation.utils.dist import initialize_torch_distributed
-from text_generation.utils.hub import (
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
     download_weights,
@@ -8,7 +8,7 @@ from text_generation.utils.hub import (
     LocalEntryNotFoundError,
     RevisionNotFoundError,
 )
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     Greedy,
     NextTokenChooser,
     Sampling,
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 00f4e64f..41dfabd3 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -12,8 +12,8 @@ from transformers import (
 from typing import List, Tuple, Optional
 
 from text_generation_server.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
-from text_generation.utils.watermark import WatermarkLogitsProcessor
+from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 
 
 class Sampling: