From 8453eca41b7ef6f6a04813b79e8cbf441d033b12 Mon Sep 17 00:00:00 2001
From: marcusdunn
Date: Thu, 10 Aug 2023 09:32:49 -0700
Subject: [PATCH] added logit_bias to the tgi server using `SequenceBiasLogitsProcessor`

Build a `SequenceBiasLogitsProcessor` in both `NextTokenChooser` and
`HeterogeneousNextTokenChooser` when a `logit_bias` mapping with at least
one non-zero value is supplied, and apply it to the scores right after the
repetition penalty processor. When the mapping is missing, empty or all
zeros, no processor is created and the scores are left untouched.

`HeterogeneousNextTokenChooser` takes `logit_bias` as an optional trailing
argument (default `None`) so call sites that do not pass it keep working.
---
Review notes:
- The hunks below do not add an import for `SequenceBiasLogitsProcessor`,
  `Dict` or `Tuple`; confirm that tokens.py already has them in scope.
- `HeterogeneousNextTokenChooser` receives one `logit_bias` mapping for the
  whole batch, while its other arguments are per-request lists; confirm
  this is intended.
- The post-image hash on the `index` line is the one from the original
  submission and was not recomputed.

 server/text_generation_server/utils/tokens.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b83af591..cbb3dafc 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -33,6 +33,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        logit_bias=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -42,6 +43,11 @@ class NextTokenChooser:
             if repetition_penalty
             else None
         )
+        self.sequence_bias_logits_processor = (
+            SequenceBiasLogitsProcessor(sequence_bias=logit_bias)
+            if logit_bias and any(bias != 0.0 for bias in logit_bias.values())
+            else None
+        )
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -64,6 +70,8 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.sequence_bias_logits_processor is not None:
+            scores = self.sequence_bias_logits_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -90,6 +98,7 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            logit_bias=pb.logit_bias,
         )
 
 
@@ -164,6 +173,7 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
+        logit_bias: Dict[Tuple[int], float] = None,
     ):
         warpers = []
 
@@ -187,6 +197,12 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        self.sequence_bias_logits_processor = (
+            SequenceBiasLogitsProcessor(sequence_bias=logit_bias)
+            if logit_bias and any(bias != 0.0 for bias in logit_bias.values())
+            else None
+        )
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -224,6 +240,8 @@ class HeterogeneousNextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.sequence_bias_logits_processor is not None:
+            scores = self.sequence_bias_logits_processor(input_ids, scores)
 
         for warper in self.warpers:
             scores = warper(input_ids, scores)