mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Added logit_bias to the TGI server using SequenceBiasLogitsProcessor
This commit is contained in:
parent
20b05bc8ba
commit
8453eca41b
@ -33,6 +33,7 @@ class NextTokenChooser:
|
||||
do_sample=False,
|
||||
seed=0,
|
||||
device="cpu",
|
||||
logit_bias=None,
|
||||
):
|
||||
self.watermark_processor = (
|
||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||
@ -42,6 +43,9 @@ class NextTokenChooser:
|
||||
if repetition_penalty
|
||||
else None
|
||||
)
|
||||
self.sequence_bias_logits_processor = (
|
||||
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
|
||||
) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
||||
|
||||
has_warpers = (
|
||||
(temperature is not None and temperature != 1.0)
|
||||
@ -64,6 +68,8 @@ class NextTokenChooser:
|
||||
scores = self.watermark_processor(input_ids, scores)
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids, scores)
|
||||
if self.sequence_bias_logits_processor is not None:
|
||||
scores = self.sequence_bias_logits_processor(input_ids, scores)
|
||||
|
||||
if self.static_warper is None:
|
||||
next_logprob = torch.log_softmax(scores, -1)
|
||||
@ -90,6 +96,7 @@ class NextTokenChooser:
|
||||
do_sample=pb.do_sample,
|
||||
seed=pb.seed,
|
||||
device=device,
|
||||
logit_bias=pb.logit_bias,
|
||||
)
|
||||
|
||||
|
||||
@ -164,6 +171,7 @@ class HeterogeneousNextTokenChooser:
|
||||
typical_p: List[float],
|
||||
do_sample: List[bool],
|
||||
seeds: List[int],
|
||||
logit_bias: Dict[Tuple[int], float],
|
||||
):
|
||||
warpers = []
|
||||
|
||||
@ -187,6 +195,10 @@ class HeterogeneousNextTokenChooser:
|
||||
else None
|
||||
)
|
||||
|
||||
self.sequence_bias_logits_processor = (
|
||||
SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
|
||||
) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
||||
|
||||
if any([x != 1.0 for x in temperature]):
|
||||
do_sample = [
|
||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||
@ -224,6 +236,8 @@ class HeterogeneousNextTokenChooser:
|
||||
scores = self.watermark_processor(input_ids, scores)
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids, scores)
|
||||
if self.sequence_bias_logits_processor is not None:
|
||||
scores = self.sequence_bias_logits_processor(input_ids, scores)
|
||||
|
||||
for warper in self.warpers:
|
||||
scores = warper(input_ids, scores)
|
||||
|
Loading…
Reference in New Issue
Block a user