mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
added missing imports of SequenceBiasLogitsProcessor
and typings.Dict
This commit is contained in:
parent
25c48f5679
commit
c70dea3802
@ -4,8 +4,9 @@ import torch
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
RepetitionPenaltyLogitsProcessor,
|
RepetitionPenaltyLogitsProcessor,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
|
SequenceBiasLogitsProcessor,
|
||||||
)
|
)
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional, Dict
|
||||||
|
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||||
@ -166,6 +167,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
self,
|
self,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
watermark: List[bool],
|
watermark: List[bool],
|
||||||
temperature: List[float],
|
temperature: List[float],
|
||||||
repetition_penalty: List[float],
|
repetition_penalty: List[float],
|
||||||
@ -239,6 +241,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
self.do_sample = do_sample
|
self.do_sample = do_sample
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
||||||
if self.watermark_processor is not None:
|
if self.watermark_processor is not None:
|
||||||
@ -299,11 +302,12 @@ class HeterogeneousNextTokenChooser:
|
|||||||
typical_p=[pb_.typical_p for pb_ in pb],
|
typical_p=[pb_.typical_p for pb_ in pb],
|
||||||
do_sample=[pb_.do_sample for pb_ in pb],
|
do_sample=[pb_.do_sample for pb_ in pb],
|
||||||
seeds=[pb_.seed for pb_ in pb],
|
seeds=[pb_.seed for pb_ in pb],
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
logit_bias=[dict(
|
logit_bias=[dict(
|
||||||
[(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
|
[(tuple(tokenizer.encode(bias.string, add_special_tokens=False).input_ids[0]), bias.bias) for bias in
|
||||||
pb_.logit_bias]) for pb_ in pb],
|
pb_.logit_bias]) for pb_ in pb],
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user