mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
formatting
This commit is contained in:
parent
1c9d953962
commit
d67a2e22fa
@ -48,10 +48,10 @@ class NextTokenChooser:
|
||||
) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
|
||||
|
||||
has_warpers = (
|
||||
(temperature is not None and temperature != 1.0)
|
||||
or (top_k is not None and top_k != 0)
|
||||
or (top_p is not None and top_p < 1.0)
|
||||
or (typical_p is not None and typical_p < 1.0)
|
||||
(temperature is not None and temperature != 1.0)
|
||||
or (top_k is not None and top_k != 0)
|
||||
or (top_p is not None and top_p < 1.0)
|
||||
or (typical_p is not None and typical_p < 1.0)
|
||||
)
|
||||
if has_warpers:
|
||||
self.static_warper = static_warper(
|
||||
@ -82,10 +82,10 @@ class NextTokenChooser:
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.NextTokenChooserParameters,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cls,
|
||||
pb: generate_pb2.NextTokenChooserParameters,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> "NextTokenChooser":
|
||||
return NextTokenChooser(
|
||||
watermark=pb.watermark,
|
||||
@ -116,11 +116,11 @@ class StopSequenceCriteria:
|
||||
|
||||
class StoppingCriteria:
|
||||
def __init__(
|
||||
self,
|
||||
eos_token_id: int,
|
||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||
max_new_tokens: int = 20,
|
||||
ignore_eos_token: bool = False,
|
||||
self,
|
||||
eos_token_id: int,
|
||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||
max_new_tokens: int = 20,
|
||||
ignore_eos_token: bool = False,
|
||||
):
|
||||
self.eos_token_id = eos_token_id
|
||||
self.stop_sequence_criterias = stop_sequence_criterias
|
||||
@ -146,9 +146,9 @@ class StoppingCriteria:
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.StoppingCriteriaParameters,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cls,
|
||||
pb: generate_pb2.StoppingCriteriaParameters,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> "StoppingCriteria":
|
||||
stop_sequence_criterias = [
|
||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||
@ -163,18 +163,18 @@ class StoppingCriteria:
|
||||
|
||||
class HeterogeneousNextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
watermark: List[bool],
|
||||
temperature: List[float],
|
||||
repetition_penalty: List[float],
|
||||
top_k: List[int],
|
||||
top_p: List[float],
|
||||
typical_p: List[float],
|
||||
do_sample: List[bool],
|
||||
seeds: List[int],
|
||||
logit_bias: List[Dict[Tuple[int], float]],
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
watermark: List[bool],
|
||||
temperature: List[float],
|
||||
repetition_penalty: List[float],
|
||||
top_k: List[int],
|
||||
top_p: List[float],
|
||||
typical_p: List[float],
|
||||
do_sample: List[bool],
|
||||
seeds: List[int],
|
||||
logit_bias: List[Dict[Tuple[int], float]],
|
||||
):
|
||||
warpers = []
|
||||
|
||||
@ -284,11 +284,11 @@ class HeterogeneousNextTokenChooser:
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cls,
|
||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> "HeterogeneousNextTokenChooser":
|
||||
return HeterogeneousNextTokenChooser(
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
|
Loading…
Reference in New Issue
Block a user