fixed HeterogeneousNextTokenChooser by using HeterogeneousProcessorWrapper with SequenceBiasLogitsProcessor

marcusdunn 2023-08-10 09:49:55 -07:00
parent 8453eca41b
commit a64c2a6f89

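The functional change is in HeterogeneousNextTokenChooser: instead of handing a single SequenceBiasLogitsProcessor one bias mapping for the whole batch, each request's bias mapping now gets its own processor, dispatched per batch row through HeterogeneousProcessorWrapper; the remaining hunks appear to only re-indent signatures touched by the parent commit. A minimal sketch of the dispatch pattern the fix relies on (the real class lives in text_generation_server.utils.logits_process and may differ in detail):

from typing import Dict

import torch
from transformers import LogitsProcessor


class HeterogeneousProcessorWrapper(LogitsProcessor):
    # Sketch: hold one LogitsProcessor per batch row and apply each
    # processor only to its own slice of the batched scores tensor.
    def __init__(self, processors: Dict[int, LogitsProcessor]):
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for i, processor in self.processors.items():
            # Slicing with i : i + 1 keeps the batch dimension that
            # every transformers LogitsProcessor expects.
            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
        return scores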

@@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import (
 class NextTokenChooser:
     def __init__(
-        self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
-        logit_bias=None,
+        self,
+        watermark=False,
+        temperature=1.0,
+        repetition_penalty=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+        do_sample=False,
+        seed=0,
+        device="cpu",
+        logit_bias=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -44,14 +44,14 @@ class NextTokenChooser:
             else None
         )
         self.sequence_bias_logits_processor = (
-            SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
+            SequenceBiasLogitsProcessor(sequence_bias=logit_bias)
         ) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
         has_warpers = (
-            (temperature is not None and temperature != 1.0)
-            or (top_k is not None and top_k != 0)
-            or (top_p is not None and top_p < 1.0)
-            or (typical_p is not None and typical_p < 1.0)
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
             self.static_warper = static_warper(
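For the single-request NextTokenChooser, one SequenceBiasLogitsProcessor is enough, since logit_bias is a single mapping from token-id tuples to bias values. A standalone example of that transformers processor, with invented token ids, bias values, and vocabulary size:

import torch
from transformers import SequenceBiasLogitsProcessor

# Hypothetical biases: always penalize token 42; boost token 9 only when
# the tokens before it match the rest of its key, i.e. right after token 7.
bias = {(42,): -10.0, (7, 9): 5.0}
processor = SequenceBiasLogitsProcessor(sequence_bias=bias)

input_ids = torch.tensor([[1, 2, 7]])  # one sequence that currently ends in 7
scores = torch.zeros(1, 50)            # fake logits over a 50-token vocabulary
biased = processor(input_ids, scores)  # biased[0, 42] == -10.0, biased[0, 9] == 5.0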
@@ -82,9 +82,9 @@ class NextTokenChooser:
     @classmethod
     def from_pb(
-        cls,
-        pb: generate_pb2.NextTokenChooserParameters,
-        device: torch.device,
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -113,11 +113,11 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
-        self,
-        eos_token_id: int,
-        stop_sequence_criterias: List[StopSequenceCriteria],
-        max_new_tokens: int = 20,
-        ignore_eos_token: bool = False,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens: int = 20,
+        ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -143,9 +143,9 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
-        cls,
-        pb: generate_pb2.StoppingCriteriaParameters,
-        tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -160,18 +160,18 @@ class StoppingCriteria:
 class HeterogeneousNextTokenChooser:
     def __init__(
-        self,
-        dtype: torch.dtype,
-        device: torch.device,
-        watermark: List[bool],
-        temperature: List[float],
-        repetition_penalty: List[float],
-        top_k: List[int],
-        top_p: List[float],
-        typical_p: List[float],
-        do_sample: List[bool],
-        seeds: List[int],
-        logit_bias: Dict[Tuple[int], float],
+        self,
+        dtype: torch.dtype,
+        device: torch.device,
+        watermark: List[bool],
+        temperature: List[float],
+        repetition_penalty: List[float],
+        top_k: List[int],
+        top_p: List[float],
+        typical_p: List[float],
+        do_sample: List[bool],
+        seeds: List[int],
+        logit_bias: List[Dict[Tuple[int], float]],
     ):
         warpers = []
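The signature change above is the heart of the fix: logit_bias now carries one bias mapping per request in the batch rather than a single shared mapping, e.g. (hypothetical values):

# One entry per batch slot: request 0 suppresses token 50256, request 1
# asks for no biasing, request 2 boosts the two-token sequence (11, 13).
logit_bias = [
    {(50256,): -100.0},
    {},
    {(11, 13): 2.5},
]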
@@ -196,8 +196,14 @@ class HeterogeneousNextTokenChooser:
         )
         self.sequence_bias_logits_processor = (
-            SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
-        ) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
+            HeterogeneousProcessorWrapper({
+                i: SequenceBiasLogitsProcessor(
+                    bias
+                )
+                for i, bias in enumerate(logit_bias)
+                if any([bias[k] != 0.0 for k in bias])
+            })
+        ) if logit_bias else None
         if any([x != 1.0 for x in temperature]):
             do_sample = [
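The comprehension builds a processor only for batch rows whose bias mapping has a non-zero value, keyed by row index so the wrapper knows which slice each processor owns. A standalone sketch of the same construction, reusing the wrapper sketched above and hypothetical bias data:

import torch
from transformers import SequenceBiasLogitsProcessor

# Rows 0 and 2 carry non-zero biases; row 1 gets no processor at all.
logit_bias = [{(42,): -10.0}, {}, {(7,): 3.0}]

sequence_bias_logits_processor = (
    HeterogeneousProcessorWrapper({
        i: SequenceBiasLogitsProcessor(bias)
        for i, bias in enumerate(logit_bias)
        if any([bias[k] != 0.0 for k in bias])
    })
) if logit_bias else None

input_ids = torch.tensor([[1, 2, 7]] * 3)
scores = torch.zeros(3, 50)
scores = sequence_bias_logits_processor(input_ids, scores)
# Rows 0 and 2 are now biased at tokens 42 and 7; row 1 is untouched.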
@@ -275,10 +281,10 @@ class HeterogeneousNextTokenChooser:
     @classmethod
     def from_pb(
-        cls,
-        pb: List[generate_pb2.NextTokenChooserParameters],
-        dtype: torch.dtype,
-        device: torch.device,
+        cls,
+        pb: List[generate_pb2.NextTokenChooserParameters],
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -291,6 +297,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            logit_bias={},
         )
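Note that from_pb passes logit_bias={} for now: an empty mapping is falsy, so the if logit_bias else None guard above leaves sequence_bias_logits_processor as None for requests built from protobuf parameters, and the mismatch with the List[Dict[Tuple[int], float]] annotation is harmless because the argument is only ever iterated or checked for truthiness.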