fixed HeterogeneousNextTokenChooser by using HeterogeneousProcessorWrapper with SequenceBiasLogitsProcessor

commit a64c2a6f89
parent 8453eca41b
Author: marcusdunn
Date:   2023-08-10 09:49:55 -07:00


@@ -23,17 +23,17 @@ from text_generation_server.utils.logits_process import (
 class NextTokenChooser:
     def __init__(
         self,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
         logit_bias=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -44,14 +44,14 @@ class NextTokenChooser:
             else None
         )
         self.sequence_bias_logits_processor = (
-            SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
+            SequenceBiasLogitsProcessor(sequence_bias=logit_bias)
         ) if logit_bias and any([logit_bias[k] != 0.0 for k in logit_bias]) else None
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
             self.static_warper = static_warper(
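For reference, SequenceBiasLogitsProcessor here appears to be the transformers processor of the same name (available since transformers 4.31): it takes a dict keyed by tuples of token ids and adds each value to the matching logits, with single-token keys applied on every decoding step. A minimal sketch under that assumption, with invented token ids:

import torch
from transformers import SequenceBiasLogitsProcessor

# Keys are tuples of token ids; values are added to the matching logits.
# The ids 42 and 7 are hypothetical placeholders.
bias = {(42,): -10.0, (7,): 5.0}
processor = SequenceBiasLogitsProcessor(sequence_bias=bias)

input_ids = torch.tensor([[1, 2, 3]])  # a fake prompt
scores = torch.zeros(1, 100)           # logits over a fake 100-token vocabulary
biased = processor(input_ids, scores)
print(biased[0, 42].item(), biased[0, 7].item())  # -10.0 5.0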
@@ -82,9 +82,9 @@ class NextTokenChooser:
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -113,11 +113,11 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
         eos_token_id: int,
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -143,9 +143,9 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -160,18 +160,18 @@ class StoppingCriteria:
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
         dtype: torch.dtype,
         device: torch.device,
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
-        logit_bias: Dict[Tuple[int], float],
+        logit_bias: List[Dict[Tuple[int], float]],
     ):
         warpers = []
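The signature change above is the crux: logit_bias goes from a single dict shared by the whole batch to one dict per request. Illustratively (token ids invented):

# One bias dict per request in the batch; keys are tuples of token ids.
logit_bias = [
    {(42,): 2.0},   # request 0: boost token 42
    {},             # request 1: no bias
    {(7,): -5.0},   # request 2: suppress token 7
]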
@@ -196,8 +196,14 @@ class HeterogeneousNextTokenChooser:
         )
         self.sequence_bias_logits_processor = (
-            SequenceBiasLogitsProcessor(sequence_bias = logit_bias)
-        ) if any([logit_bias[k] != 0.0 for k in logit_bias]) else None
+            HeterogeneousProcessorWrapper({
+                i: SequenceBiasLogitsProcessor(
+                    bias
+                )
+                for i, bias in enumerate(logit_bias)
+                if any([bias[k] != 0.0 for k in bias])
+            })
+        ) if logit_bias else None

         if any([x != 1.0 for x in temperature]):
             do_sample = [
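This hunk is the actual fix: instead of handing one shared bias dict to a single SequenceBiasLogitsProcessor (which cannot vary per row), the batch now gets a HeterogeneousProcessorWrapper that dispatches a dedicated processor to each batch index with a non-zero bias. A simplified sketch of how such a wrapper behaves, assuming it routes per-row slices like the class in text_generation_server.utils.logits_process:

import torch

class HeterogeneousProcessorWrapper:
    """Runs a different logits processor on each row of the batch.

    `processors` maps batch index -> processor; rows without an entry
    pass through unchanged. A simplified sketch, not the real class.
    """

    def __init__(self, processors):
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for i, processor in self.processors.items():
            # Slice out row i as a batch of one, process it, write it back.
            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
        return scores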
@@ -275,10 +281,10 @@ class HeterogeneousNextTokenChooser:
     @classmethod
     def from_pb(
         cls,
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -291,6 +297,7 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            logit_bias={},
         )
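Note that from_pb still passes an empty logit_bias; since empty is falsy, the wrapper above is skipped on that path, so per-request biasing is only exercised when the chooser is constructed directly. A hypothetical two-request construction based on the signature in this diff, with illustrative values (device, dtype, and token ids are placeholders):

import torch

chooser = HeterogeneousNextTokenChooser(
    dtype=torch.float32,
    device=torch.device("cpu"),
    watermark=[False, False],
    temperature=[1.0, 0.7],
    repetition_penalty=[1.0, 1.0],
    top_k=[0, 50],
    top_p=[1.0, 0.9],
    typical_p=[1.0, 1.0],
    do_sample=[False, True],
    seeds=[0, 0],
    logit_bias=[{(42,): -10.0}, {}],  # bias only the first request
)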