mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fixed HeterogeneousNextTokenChooser
by using HeterogeneousProcessorWrapper
with SequenceBiasLogitsProcessor
This commit is contained in:
parent
8453eca41b
commit
a64c2a6f89
@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import (
|
|||||||
|
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
watermark=False,
|
watermark=False,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
repetition_penalty=1.0,
|
repetition_penalty=1.0,
|
||||||
top_k=None,
|
top_k=None,
|
||||||
top_p=None,
|
top_p=None,
|
||||||
typical_p=None,
|
typical_p=None,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
seed=0,
|
seed=0,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
logit_bias=None,
|
logit_bias=None,
|
||||||
):
|
):
|
||||||
self.watermark_processor = (
|
self.watermark_processor = (
|
||||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||||
@ -44,14 +44,14 @@ class NextTokenChooser:
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.sequence_bias_logits_processor = (
|
self.sequence_bias_logits_processor = (
|
||||||
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
|
SequenceBiasLogitsProcessor(sequence_bias=logit_bias)
|
||||||
) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
||||||
|
|
||||||
has_warpers = (
|
has_warpers = (
|
||||||
(temperature is not None and temperature != 1.0)
|
(temperature is not None and temperature != 1.0)
|
||||||
or (top_k is not None and top_k != 0)
|
or (top_k is not None and top_k != 0)
|
||||||
or (top_p is not None and top_p < 1.0)
|
or (top_p is not None and top_p < 1.0)
|
||||||
or (typical_p is not None and typical_p < 1.0)
|
or (typical_p is not None and typical_p < 1.0)
|
||||||
)
|
)
|
||||||
if has_warpers:
|
if has_warpers:
|
||||||
self.static_warper = static_warper(
|
self.static_warper = static_warper(
|
||||||
@ -82,9 +82,9 @@ class NextTokenChooser:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.NextTokenChooserParameters,
|
pb: generate_pb2.NextTokenChooserParameters,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "NextTokenChooser":
|
) -> "NextTokenChooser":
|
||||||
return NextTokenChooser(
|
return NextTokenChooser(
|
||||||
watermark=pb.watermark,
|
watermark=pb.watermark,
|
||||||
@ -113,11 +113,11 @@ class StopSequenceCriteria:
|
|||||||
|
|
||||||
class StoppingCriteria:
|
class StoppingCriteria:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
eos_token_id: int,
|
eos_token_id: int,
|
||||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||||
max_new_tokens: int = 20,
|
max_new_tokens: int = 20,
|
||||||
ignore_eos_token: bool = False,
|
ignore_eos_token: bool = False,
|
||||||
):
|
):
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.stop_sequence_criterias = stop_sequence_criterias
|
self.stop_sequence_criterias = stop_sequence_criterias
|
||||||
@ -143,9 +143,9 @@ class StoppingCriteria:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.StoppingCriteriaParameters,
|
pb: generate_pb2.StoppingCriteriaParameters,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> "StoppingCriteria":
|
) -> "StoppingCriteria":
|
||||||
stop_sequence_criterias = [
|
stop_sequence_criterias = [
|
||||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||||
@ -160,18 +160,18 @@ class StoppingCriteria:
|
|||||||
|
|
||||||
class HeterogeneousNextTokenChooser:
|
class HeterogeneousNextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
watermark: List[bool],
|
watermark: List[bool],
|
||||||
temperature: List[float],
|
temperature: List[float],
|
||||||
repetition_penalty: List[float],
|
repetition_penalty: List[float],
|
||||||
top_k: List[int],
|
top_k: List[int],
|
||||||
top_p: List[float],
|
top_p: List[float],
|
||||||
typical_p: List[float],
|
typical_p: List[float],
|
||||||
do_sample: List[bool],
|
do_sample: List[bool],
|
||||||
seeds: List[int],
|
seeds: List[int],
|
||||||
logit_bias: Dict[Tuple[int], float],
|
logit_bias: List[Dict[Tuple[int], float]],
|
||||||
):
|
):
|
||||||
warpers = []
|
warpers = []
|
||||||
|
|
||||||
@ -196,8 +196,14 @@ class HeterogeneousNextTokenChooser:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.sequence_bias_logits_processor = (
|
self.sequence_bias_logits_processor = (
|
||||||
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
|
HeterogeneousProcessorWrapper({
|
||||||
) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
i: SequenceBiasLogitsProcessor(
|
||||||
|
bias
|
||||||
|
)
|
||||||
|
for i, bias in enumerate(logit_bias)
|
||||||
|
if any([bias[k] != 0.0 for k in bias])
|
||||||
|
})
|
||||||
|
) if logit_bias else None
|
||||||
|
|
||||||
if any([x != 1.0 for x in temperature]):
|
if any([x != 1.0 for x in temperature]):
|
||||||
do_sample = [
|
do_sample = [
|
||||||
@ -275,10 +281,10 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "HeterogeneousNextTokenChooser":
|
) -> "HeterogeneousNextTokenChooser":
|
||||||
return HeterogeneousNextTokenChooser(
|
return HeterogeneousNextTokenChooser(
|
||||||
watermark=[pb_.watermark for pb_ in pb],
|
watermark=[pb_.watermark for pb_ in pb],
|
||||||
@ -291,6 +297,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
seeds=[pb_.seed for pb_ in pb],
|
seeds=[pb_.seed for pb_ in pb],
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
logit_bias={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user