fixed HeterogeneousNextTokenChooser by using HeterogeneousProcessorWrapper with SequenceBiasLogitsProcessor

This commit is contained in:
marcusdunn 2023-08-10 09:49:55 -07:00
parent 8453eca41b
commit a64c2a6f89

View File

@ -171,7 +171,7 @@ class HeterogeneousNextTokenChooser:
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
logit_bias: Dict[Tuple[int], float],
logit_bias: List[Dict[Tuple[int], float]],
):
warpers = []
@ -196,8 +196,14 @@ class HeterogeneousNextTokenChooser:
)
self.sequence_bias_logits_processor = (
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
HeterogeneousProcessorWrapper({
i: SequenceBiasLogitsProcessor(
bias
)
for i, bias in enumerate(logit_bias)
if any([bias[k] != 0.0 for k in bias])
})
) if logit_bias else None
if any([x != 1.0 for x in temperature]):
do_sample = [
@ -291,6 +297,7 @@ class HeterogeneousNextTokenChooser:
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
logit_bias={},
)