Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00

commit 1c9d953962 (parent 9a0262f38c)

    formatting
@@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import (
 class NextTokenChooser:
     def __init__(
         self,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
         logit_bias={},
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
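For context, the constructor above accepts a logit_bias mapping of token-id tuples to bias values alongside the usual sampling parameters. A minimal sketch of instantiating it, assuming the forked text_generation_server.utils.tokens module from this branch is importable; the token id and all parameter values below are illustrative, not taken from the diff:

from text_generation_server.utils.tokens import NextTokenChooser

# Hypothetical bias: push token id 50256 strongly down (illustrative values only).
bias = {(50256,): -100.0}

chooser = NextTokenChooser(
    watermark=False,
    temperature=0.9,
    repetition_penalty=1.1,
    do_sample=True,
    seed=42,
    device="cpu",
    logit_bias=bias,
)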
@@ -97,7 +97,9 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
-            logit_bias=dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb.logit_bias]),
+            logit_bias=dict(
+                [(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
+                 pb.logit_bias]),
         )
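The rewrapped comprehension builds the logit_bias dict by tokenizing each requested bias string and keying the resulting token ids to the bias value. A standalone sketch of the same idea using a transformers tokenizer directly; the model name, strings, and bias values are assumptions for illustration, and a plain dict stands in for the fork's protobuf fields bias.string and bias.bias:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Stand-in for pb.logit_bias: requested strings and their additive logit biases.
requested = {"Hello": 5.0, " world": -5.0}

# Key each string's token ids (no special tokens) to its bias, i.e. the
# Dict[Tuple[int, ...], float] shape that SequenceBiasLogitsProcessor expects.
logit_bias = {
    tuple(tokenizer.encode(text, add_special_tokens=False)): value
    for text, value in requested.items()
}
print(logit_bias)  # e.g. {(15496,): 5.0, (995,): -5.0}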
@@ -199,7 +201,7 @@ class HeterogeneousNextTokenChooser:
         self.sequence_bias_logits_processor = (
             HeterogeneousProcessorWrapper({
                 i: SequenceBiasLogitsProcessor(
                     bias
                 )
                 for i, bias in enumerate(logit_bias)
                 if any([bias[k] != 0.0 for k in bias])
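In the heterogeneous (batched) chooser, each sequence index with a non-trivial bias gets its own SequenceBiasLogitsProcessor inside the HeterogeneousProcessorWrapper. A small sketch of what one such processor does to a row of logits, using the transformers implementation with a toy 10-token vocabulary; all ids and bias values are illustrative:

import torch
from transformers import SequenceBiasLogitsProcessor

# Bias token id 3 up and token id 7 down for this sequence.
processor = SequenceBiasLogitsProcessor(sequence_bias={(3,): 8.0, (7,): -8.0})

input_ids = torch.tensor([[1, 2]])   # one running sequence in the batch
scores = torch.zeros((1, 10))        # uniform logits over a toy vocabulary
biased = processor(input_ids, scores)
print(biased[0, 3].item(), biased[0, 7].item())  # roughly 8.0 and -8.0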
@@ -299,7 +301,9 @@ bias
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
-            logit_bias=[dict([(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in pb_.logit_bias]) for pb_ in pb],
+            logit_bias=[dict(
+                [(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
+                 pb_.logit_bias]) for pb_ in pb],
         )
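In this heterogeneous path, logit_bias is a list with one dict per request, and the comprehension in the earlier hunk only creates a processor for indices whose dict contains at least one non-zero bias. A pure-Python sketch of that filtering, with illustrative token ids and values:

# One bias dict per request, as built from each pb_.logit_bias above.
logit_bias = [
    {(15496,): 5.0},   # request 0: one non-zero bias -> gets a processor
    {(995,): 0.0},     # request 1: only zero biases -> skipped
    {},                # request 2: no biases -> skipped
]

active = {
    i: bias
    for i, bias in enumerate(logit_bias)
    if any([bias[k] != 0.0 for k in bias])
}
print(active)  # {0: {(15496,): 5.0}}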
|