Remove Warpers for Processor

This commit is contained in:
Cyril Vallez 2025-01-17 16:39:49 +00:00
parent b40c889360
commit 42ae6dea02
No known key found for this signature in database

View File

@@ -11,7 +11,6 @@ from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.guide import RegexGuide from outlines.fsm.guide import RegexGuide
from transformers import ( from transformers import (
LogitsWarper,
LogitsProcessor, LogitsProcessor,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
TemperatureLogitsWarper, TemperatureLogitsWarper,
@@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
return None return None
class HeterogeneousTopPLogitsWarper(LogitsWarper): class HeterogeneousTopPLogitsWarper(LogitsProcessor):
""" """
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
@@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
return None return None
class HeterogeneousTopKLogitsWarper(LogitsWarper): class HeterogeneousTopKLogitsWarper(LogitsProcessor):
r""" r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
@@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
return None return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper): class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
r""" r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information. Generation](https://arxiv.org/abs/2202.00666) for more information.
@@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
r""" r"""
A wrapper for logit warpers or processors without heterogeneous parameter support. A wrapper for logit warpers or processors without heterogeneous parameter support.
Args: Args:
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): processors (`Dict[int, LogitsProcessor]`):
A mapping of sample indices to logit warpers or processors, to be run sequentially. A mapping of sample indices to logit warpers or processors, to be run sequentially.
""" """
def __init__( def __init__(
self, self,
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], processors: Dict[int, LogitsProcessor],
): ):
self.processors = processors self.processors = processors