formatting

This commit is contained in:
marcusdunn 2023-08-10 10:57:34 -07:00
parent 9a0262f38c
commit 1c9d953962

View File

@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import (
class NextTokenChooser: class NextTokenChooser:
def __init__( def __init__(
self, self,
watermark=False, watermark=False,
temperature=1.0, temperature=1.0,
repetition_penalty=1.0, repetition_penalty=1.0,
top_k=None, top_k=None,
top_p=None, top_p=None,
typical_p=None, typical_p=None,
do_sample=False, do_sample=False,
seed=0, seed=0,
device="cpu", device="cpu",
logit_bias={}, logit_bias={},
): ):
self.watermark_processor = ( self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None WatermarkLogitsProcessor(device=device) if watermark else None
@ -97,7 +97,9 @@ class NextTokenChooser:
do_sample=pb.do_sample, do_sample=pb.do_sample,
seed=pb.seed, seed=pb.seed,
device=device, device=device,
logit_bias=dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb.logit_bias]), logit_bias=dict(
[(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
pb.logit_bias]),
) )
@ -199,7 +201,7 @@ class HeterogeneousNextTokenChooser:
self.sequence_bias_logits_processor = ( self.sequence_bias_logits_processor = (
HeterogeneousProcessorWrapper({ HeterogeneousProcessorWrapper({
i: SequenceBiasLogitsProcessor( i: SequenceBiasLogitsProcessor(
bias bias
) )
for i, bias in enumerate(logit_bias) for i, bias in enumerate(logit_bias)
if any([bias[k] != 0.0 for k in bias]) if any([bias[k] != 0.0 for k in bias])
@ -299,7 +301,9 @@ bias
seeds=[pb_.seed for pb_ in pb], seeds=[pb_.seed for pb_ in pb],
device=device, device=device,
dtype=dtype, dtype=dtype,
logit_bias=[dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb_.logit_bias]) for pb_ in pb], logit_bias=[dict(
[(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
pb_.logit_bias]) for pb_ in pb],
) )