diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index cbb3dafc..f99405ad 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import (
 
 class NextTokenChooser:
     def __init__(
-            self,
-            watermark=False,
-            temperature=1.0,
-            repetition_penalty=1.0,
-            top_k=None,
-            top_p=None,
-            typical_p=None,
-            do_sample=False,
-            seed=0,
-            device="cpu",
-            logit_bias=None,
+        self,
+        watermark=False,
+        temperature=1.0,
+        repetition_penalty=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+        do_sample=False,
+        seed=0,
+        device="cpu",
+        logit_bias=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -44,14 +44,14 @@ class NextTokenChooser:
             else None
         )
         self.sequence_bias_logits_processor = (
-            SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
+            SequenceBiasLogitsProcessor(sequence_bias=logit_bias)
         ) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
 
         has_warpers = (
-                (temperature is not None and temperature != 1.0)
-                or (top_k is not None and top_k != 0)
-                or (top_p is not None and top_p < 1.0)
-                or (typical_p is not None and typical_p < 1.0)
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
             self.static_warper = static_warper(
@@ -82,9 +82,9 @@ class NextTokenChooser:
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.NextTokenChooserParameters,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -113,11 +113,11 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
-            self,
-            eos_token_id: int,
-            stop_sequence_criterias: List[StopSequenceCriteria],
-            max_new_tokens: int = 20,
-            ignore_eos_token: bool = False,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens: int = 20,
+        ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -143,9 +143,9 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.StoppingCriteriaParameters,
-            tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -160,18 +160,18 @@ class StoppingCriteria:
 
 class HeterogeneousNextTokenChooser:
     def __init__(
-            self,
-            dtype: torch.dtype,
-            device: torch.device,
-            watermark: List[bool],
-            temperature: List[float],
-            repetition_penalty: List[float],
-            top_k: List[int],
-            top_p: List[float],
-            typical_p: List[float],
-            do_sample: List[bool],
-            seeds: List[int],
-            logit_bias: Dict[Tuple[int], float],
+        self,
+        dtype: torch.dtype,
+        device: torch.device,
+        watermark: List[bool],
+        temperature: List[float],
+        repetition_penalty: List[float],
+        top_k: List[int],
+        top_p: List[float],
+        typical_p: List[float],
+        do_sample: List[bool],
+        seeds: List[int],
+        logit_bias: List[Dict[Tuple[int], float]],
     ):
         warpers = []
 
@@ -196,8 +196,14 @@ class HeterogeneousNextTokenChooser:
         )
 
         self.sequence_bias_logits_processor = (
-            SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
-        ) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
+            HeterogeneousProcessorWrapper({
+                i: SequenceBiasLogitsProcessor(
+                    bias
+                )
+                for i, bias in enumerate(logit_bias)
+                if any([bias[k] != 0.0 for k in bias])
+            })
+        ) if logit_bias else None
 
         if any([x != 1.0 for x in temperature]):
             do_sample = [
@@ -275,10 +281,10 @@ class HeterogeneousNextTokenChooser:
 
     @classmethod
     def from_pb(
-            cls,
-            pb: List[generate_pb2.NextTokenChooserParameters],
-            dtype: torch.dtype,
-            device: torch.device,
+        cls,
+        pb: List[generate_pb2.NextTokenChooserParameters],
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -291,6 +297,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            logit_bias=[],
         )
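Note for reviewers: the behavioral change in the `@@ -196,8 +196,14 @@` hunk is that `logit_bias` becomes a per-request list for the heterogeneous chooser, and only requests whose bias dict contains a non-zero value get a `SequenceBiasLogitsProcessor`; all other batch rows are left untouched. Below is a minimal sketch of that row-wise behavior, not TGI code: it uses plain torch, handles single-token bias keys only (the real processor also matches multi-token sequences), and `apply_per_request_bias` is a hypothetical helper named for illustration.

# Sketch only: emulates what the HeterogeneousProcessorWrapper over
# per-request SequenceBiasLogitsProcessor instances does for
# single-token bias keys.
import torch

def apply_per_request_bias(scores, logit_bias):
    # scores: [batch, vocab_size] logits; logit_bias: one dict per request,
    # mapping token-id tuples such as (token_id,) to an additive bias.
    for i, bias in enumerate(logit_bias):
        if not any(v != 0.0 for v in bias.values()):
            continue  # mirrors the `if any(...)` filter in the dict comprehension
        for token_ids, value in bias.items():
            scores[i, token_ids[-1]] += value  # bias only this request's row
    return scores

scores = torch.zeros(2, 8)
out = apply_per_request_bias(scores, [{(3,): 5.0}, {}])
assert out[0, 3] == 5.0 and out[1].sum() == 0  # row 1 had no non-zero bias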