diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 132e441b..5066de53 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -11,7 +11,6 @@ from text_generation_server.pb.generate_pb2 import GrammarType from outlines.fsm.guide import RegexGuide from transformers import ( - LogitsWarper, LogitsProcessor, PreTrainedTokenizerBase, TemperatureLogitsWarper, @@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper: return None -class HeterogeneousTopPLogitsWarper(LogitsWarper): +class HeterogeneousTopPLogitsWarper(LogitsProcessor): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. This version allows for a separate value for each sample and runs inplace when possible. @@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): return None -class HeterogeneousTopKLogitsWarper(LogitsWarper): +class HeterogeneousTopKLogitsWarper(LogitsProcessor): r""" [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. This version allows for a separate value for each sample and runs inplace when possible. @@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): return None -class HeterogeneousTypicalLogitsWarper(LogitsWarper): +class HeterogeneousTypicalLogitsWarper(LogitsProcessor): r""" [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. @@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): r""" A wrapper for logit warpers or processors without heterogeneous parameter support. Args: - processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): + processors (`Dict[int, LogitsProcessor]`): A mapping of sample indices to logit warpers or processors, to be run sequentially. """ def __init__( self, - processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], + processors: Dict[int, LogitsProcessor], ): self.processors = processors