added logit_bias to the tgi server using SequenceBiasLogitsProcessor

This commit is contained in:
marcusdunn 2023-08-10 09:32:49 -07:00
parent 20b05bc8ba
commit 8453eca41b

View File

@ -33,6 +33,7 @@ class NextTokenChooser:
do_sample=False,
seed=0,
device="cpu",
logit_bias=None,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@ -42,6 +43,9 @@ class NextTokenChooser:
if repetition_penalty
else None
)
self.sequence_bias_logits_processor = (
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
has_warpers = (
(temperature is not None and temperature != 1.0)
@ -64,6 +68,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.sequence_bias_logits_processor is not None:
scores = self.sequence_bias_logits_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@ -90,6 +96,7 @@ class NextTokenChooser:
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
logit_bias=pb.logit_bias,
)
@ -164,6 +171,7 @@ class HeterogeneousNextTokenChooser:
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
logit_bias: Dict[Tuple[int], float],
):
warpers = []
@ -187,6 +195,10 @@ class HeterogeneousNextTokenChooser:
else None
)
self.sequence_bias_logits_processor = (
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@ -224,6 +236,8 @@ class HeterogeneousNextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.sequence_bias_logits_processor is not None:
scores = self.sequence_bias_logits_processor(input_ids, scores)
for warper in self.warpers:
scores = warper(input_ids, scores)