mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Remove Warpers for Processor
This commit is contained in:
parent
b40c889360
commit
42ae6dea02
@ -11,7 +11,6 @@ from text_generation_server.pb.generate_pb2 import GrammarType
|
|||||||
from outlines.fsm.guide import RegexGuide
|
from outlines.fsm.guide import RegexGuide
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
LogitsWarper,
|
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
TemperatureLogitsWarper,
|
TemperatureLogitsWarper,
|
||||||
@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
|
||||||
"""
|
"""
|
||||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||||
This version allows for a separate value for each sample and runs inplace when possible.
|
This version allows for a separate value for each sample and runs inplace when possible.
|
||||||
@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||||
This version allows for a separate value for each sample and runs inplace when possible.
|
This version allows for a separate value for each sample and runs inplace when possible.
|
||||||
@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||||
@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
|
|||||||
r"""
|
r"""
|
||||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||||
Args:
|
Args:
|
||||||
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
processors (`Dict[int, LogitsProcessor]`):
|
||||||
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
processors: Dict[int, LogitsProcessor],
|
||||||
):
|
):
|
||||||
self.processors = processors
|
self.processors = processors
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user