added missing imports of SequenceBiasLogitsProcessor and typing.Dict

commit c70dea3802
parent 25c48f5679
Author: marcusdunn
Date:   2023-08-15 15:13:07 -07:00


@@ -4,8 +4,9 @@ import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
+    SequenceBiasLogitsProcessor,
 )
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Dict
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
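For context on the newly imported class: SequenceBiasLogitsProcessor takes a sequence_bias dict mapping tuples of token ids to a float bias, which is also why Dict joins the typing imports. A minimal construction sketch, independent of this repo; the "gpt2" checkpoint is purely illustrative:

```python
from transformers import AutoTokenizer, SequenceBiasLogitsProcessor

# Illustrative checkpoint; any tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Keys are full token-id tuples for the biased string; values are the
# bias added to the logits (negative discourages, positive encourages).
ids = tuple(tokenizer.encode("Hello", add_special_tokens=False))
processor = SequenceBiasLogitsProcessor(sequence_bias={ids: -10.0})
```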
@@ -166,6 +167,7 @@ class HeterogeneousNextTokenChooser:
         self,
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
@@ -239,6 +241,7 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
+        self.tokenizer = tokenizer

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
         if self.watermark_processor is not None:
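The __call__ signature visible in the context lines is the standard logits-processor interface, (input_ids, scores) -> scores, which is also how a SequenceBiasLogitsProcessor built from the stored tokenizer would slot in. A small illustration of that interface (the vocab size and token id are arbitrary):

```python
import torch
from transformers import SequenceBiasLogitsProcessor

# Single-token keys are biased unconditionally, so the effect shows up
# on the very first call.
processor = SequenceBiasLogitsProcessor(sequence_bias={(42,): -5.0})

input_ids = torch.tensor([[1, 2, 3]])  # arbitrary prompt ids
scores = torch.zeros(1, 1000)          # arbitrary vocab size
adjusted = processor(input_ids, scores)
assert adjusted[0, 42].item() == -5.0  # token 42's logit was lowered
```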
@@ -299,11 +302,12 @@ class HeterogeneousNextTokenChooser:
             typical_p=[pb_.typical_p for pb_ in pb],
             do_sample=[pb_.do_sample for pb_ in pb],
             seeds=[pb_.seed for pb_ in pb],
-            device=device,
-            dtype=dtype,
             logit_bias=[dict(
                 [(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
                  pb_.logit_bias]) for pb_ in pb],
+            device=device,
+            dtype=dtype,
+            tokenizer=tokenizer,
         )
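One caveat worth flagging: as committed, the logit_bias comprehension would fail at runtime, since tokenizer.encode returns a plain List[int] with no .input_ids attribute, and even after that fix, tuple(...[0]) would try to convert a single int. A corrected sketch of the apparent intent, mapping each bias string to its full token-id tuple (the string/bias field names are taken from the diff; the helper name is hypothetical):

```python
from typing import Dict, Tuple

def build_sequence_bias(tokenizer, logit_bias) -> Dict[Tuple[int, ...], float]:
    # Encode each bias string to its complete token-id tuple, the key
    # shape SequenceBiasLogitsProcessor expects, mapped to its bias value.
    return {
        tuple(tokenizer.encode(b.string, add_special_tokens=False)): b.bias
        for b in logit_bias
    }
```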