diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index c7888f6f..5460ccfd 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -48,10 +48,10 @@ class NextTokenChooser:
         ) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
 
         has_warpers = (
-                (temperature is not None and temperature != 1.0)
-                or (top_k is not None and top_k != 0)
-                or (top_p is not None and top_p < 1.0)
-                or (typical_p is not None and typical_p < 1.0)
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
             self.static_warper = static_warper(
@@ -82,10 +82,10 @@ class NextTokenChooser:
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.NextTokenChooserParameters,
-            device: torch.device,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -116,11 +116,11 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
-            self,
-            eos_token_id: int,
-            stop_sequence_criterias: List[StopSequenceCriteria],
-            max_new_tokens: int = 20,
-            ignore_eos_token: bool = False,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens: int = 20,
+        ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -146,9 +146,9 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.StoppingCriteriaParameters,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -163,18 +163,18 @@ class StoppingCriteria:
 
 class HeterogeneousNextTokenChooser:
     def __init__(
-            self,
-            dtype: torch.dtype,
-            device: torch.device,
-            watermark: List[bool],
-            temperature: List[float],
-            repetition_penalty: List[float],
-            top_k: List[int],
-            top_p: List[float],
-            typical_p: List[float],
-            do_sample: List[bool],
-            seeds: List[int],
-            logit_bias: List[Dict[Tuple[int], float]],
+        self,
+        dtype: torch.dtype,
+        device: torch.device,
+        watermark: List[bool],
+        temperature: List[float],
+        repetition_penalty: List[float],
+        top_k: List[int],
+        top_p: List[float],
+        typical_p: List[float],
+        do_sample: List[bool],
+        seeds: List[int],
+        logit_bias: List[Dict[Tuple[int], float]],
     ):
         warpers = []
 
@@ -284,11 +284,11 @@ class HeterogeneousNextTokenChooser:
 
     @classmethod
     def from_pb(
-            cls,
-            pb: List[generate_pb2.NextTokenChooserParameters],
-            dtype: torch.dtype,
-            device: torch.device,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: List[generate_pb2.NextTokenChooserParameters],
+        dtype: torch.dtype,
+        device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],