mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
added missing imports of SequenceBiasLogitsProcessor
and typings.Dict
This commit is contained in:
parent
25c48f5679
commit
c70dea3802
@ -4,8 +4,9 @@ import torch
|
||||
from transformers import (
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
SequenceBiasLogitsProcessor,
|
||||
)
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
@ -166,6 +167,7 @@ class HeterogeneousNextTokenChooser:
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
watermark: List[bool],
|
||||
temperature: List[float],
|
||||
repetition_penalty: List[float],
|
||||
@ -239,6 +241,7 @@ class HeterogeneousNextTokenChooser:
|
||||
self.do_sample = do_sample
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
||||
if self.watermark_processor is not None:
|
||||
@ -299,11 +302,12 @@ class HeterogeneousNextTokenChooser:
|
||||
typical_p=[pb_.typical_p for pb_ in pb],
|
||||
do_sample=[pb_.do_sample for pb_ in pb],
|
||||
seeds=[pb_.seed for pb_ in pb],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
logit_bias=[dict(
|
||||
[(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
|
||||
pb_.logit_bias]) for pb_ in pb],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user