mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Remove Warpers for Processor
This commit is contained in:
parent
b40c889360
commit
42ae6dea02
@ -11,7 +11,6 @@ from text_generation_server.pb.generate_pb2 import GrammarType
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
|
||||
from transformers import (
|
||||
LogitsWarper,
|
||||
LogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
TemperatureLogitsWarper,
|
||||
@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
|
||||
"""
|
||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||
@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
|
||||
r"""
|
||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||
Args:
|
||||
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
||||
processors (`Dict[int, LogitsProcessor]`):
|
||||
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
||||
processors: Dict[int, LogitsProcessor],
|
||||
):
|
||||
self.processors = processors
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user