mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fixed HeterogeneousNextTokenChooser
by using HeterogeneousProcessorWrapper
with SequenceBiasLogitsProcessor
This commit is contained in:
parent
8453eca41b
commit
a64c2a6f89
@ -171,7 +171,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
typical_p: List[float],
|
typical_p: List[float],
|
||||||
do_sample: List[bool],
|
do_sample: List[bool],
|
||||||
seeds: List[int],
|
seeds: List[int],
|
||||||
logit_bias: Dict[Tuple[int], float],
|
logit_bias: List[Dict[Tuple[int], float]],
|
||||||
):
|
):
|
||||||
warpers = []
|
warpers = []
|
||||||
|
|
||||||
@ -196,8 +196,14 @@ class HeterogeneousNextTokenChooser:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.sequence_bias_logits_processor = (
|
self.sequence_bias_logits_processor = (
|
||||||
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
|
HeterogeneousProcessorWrapper({
|
||||||
) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
i: SequenceBiasLogitsProcessor(
|
||||||
|
bias
|
||||||
|
)
|
||||||
|
for i, bias in enumerate(logit_bias)
|
||||||
|
if any([bias[k] != 0.0 for k in bias])
|
||||||
|
})
|
||||||
|
) if logit_bias else None
|
||||||
|
|
||||||
if any([x != 1.0 for x in temperature]):
|
if any([x != 1.0 for x in temperature]):
|
||||||
do_sample = [
|
do_sample = [
|
||||||
@ -291,6 +297,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
seeds=[pb_.seed for pb_ in pb],
|
seeds=[pb_.seed for pb_ in pb],
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
logit_bias={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user