From c70dea3802ca169e3962be26903e7ac5dfbccfa5 Mon Sep 17 00:00:00 2001
From: marcusdunn
Date: Tue, 15 Aug 2023 15:13:07 -0700
Subject: [PATCH] added missing imports of `SequenceBiasLogitsProcessor` and
 `typing.Dict`

---
 server/text_generation_server/utils/tokens.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 5460ccfd..f6ca377b 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -4,8 +4,9 @@ import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
+    SequenceBiasLogitsProcessor,
 )
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Dict
 
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
@@ -166,6 +167,7 @@ class HeterogeneousNextTokenChooser:
         self,
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
@@ -239,6 +241,7 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
+        self.tokenizer = tokenizer
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
         if self.watermark_processor is not None:
@@ -299,11 +302,12 @@ class HeterogeneousNextTokenChooser:
             typical_p=[pb_.typical_p for pb_ in pb],
             do_sample=[pb_.do_sample for pb_ in pb],
             seeds=[pb_.seed for pb_ in pb],
-            device=device,
-            dtype=dtype,
             logit_bias=[dict(
                 [(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias)
                  for bias in pb_.logit_bias]) for pb_ in pb],
+            device=device,
+            dtype=dtype,
+            tokenizer=tokenizer,
         )
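
The context lines above build the mapping consumed by `SequenceBiasLogitsProcessor`,
which expects `Dict[Tuple[int, ...], float]`: one tuple of token ids per biased
string. What follows is a minimal, self-contained sketch of that construction,
not part of the patch; `AutoTokenizer`, the "gpt2" checkpoint, and `bias_strings`
are illustrative stand-ins for the server's tokenizer and the `pb_.logit_bias`
protobuf entries, and it assumes a transformers release (>= 4.31) that ships
`SequenceBiasLogitsProcessor`. Note that `PreTrainedTokenizerBase.encode()`
returns a plain list of ids with no `.input_ids` attribute, so the sketch
builds the tuple directly from that list rather than following the context
line's `.input_ids[0]` access, which looks like it would raise.

    # Sketch: building a sequence-bias dict from (string, bias) pairs and
    # applying SequenceBiasLogitsProcessor the way HeterogeneousNextTokenChooser
    # chains its processors over (input_ids, scores).
    from typing import Dict, Tuple

    import torch
    from transformers import AutoTokenizer, SequenceBiasLogitsProcessor

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Hypothetical stand-in for the `pb_.logit_bias` entries in the diff.
    bias_strings = {"Hello": 5.0, "goodbye": -5.0}

    # encode() returns List[int]; one tuple of token ids per biased string.
    sequence_bias: Dict[Tuple[int, ...], float] = {
        tuple(tokenizer.encode(s, add_special_tokens=False)): b
        for s, b in bias_strings.items()
    }

    processor = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)

    # processor(input_ids, scores) returns the bias-adjusted scores.
    input_ids = tokenizer("Hello", return_tensors="pt").input_ids
    scores = torch.zeros(1, tokenizer.vocab_size)
    scores = processor(input_ids, scores)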