From 1c9d9539622b6888055ec11a524037db0636eddf Mon Sep 17 00:00:00 2001 From: marcusdunn Date: Thu, 10 Aug 2023 10:57:34 -0700 Subject: [PATCH] formatting --- server/text_generation_server/utils/tokens.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 1368f47d..c7888f6f 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import ( class NextTokenChooser: def __init__( - self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", - logit_bias={}, + self, + watermark=False, + temperature=1.0, + repetition_penalty=1.0, + top_k=None, + top_p=None, + typical_p=None, + do_sample=False, + seed=0, + device="cpu", + logit_bias={}, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -97,7 +97,9 @@ class NextTokenChooser: do_sample=pb.do_sample, seed=pb.seed, device=device, - logit_bias=dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb.logit_bias]), + logit_bias=dict( + [(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in + pb.logit_bias]), ) @@ -199,7 +201,7 @@ class HeterogeneousNextTokenChooser: self.sequence_bias_logits_processor = ( HeterogeneousProcessorWrapper({ i: SequenceBiasLogitsProcessor( -bias + bias ) for i, bias in enumerate(logit_bias) if any([bias[k] != 0.0 for k in bias]) @@ -299,7 +301,9 @@ bias seeds=[pb_.seed for pb_ in pb], device=device, dtype=dtype, - logit_bias=[dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb_.logit_bias]) for pb_ in pb], + logit_bias=[dict( + [(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in + pb_.logit_bias]) for pb_ in pb], )