diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 44ee4936..d8c7c2c6 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -20,18 +20,13 @@ from transformers import (
     TypicalLogitsWarper,
 )
 
+from transformers.generation.logits_process import _calc_banned_ngram_tokens
+
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 
 
 class StaticWarper:
-    def __init__(
-        self,
-        temperature=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        no_repeat_ngram_size=None,
-    ):
+    def __init__(self, temperature=1.0, top_k=None, top_p=None, typical_p=None):
         self.warpers = []
 
         if temperature is not None and temperature != 1.0:
@@ -83,14 +78,12 @@ def static_warper(
     top_k: Optional[int],
     top_p: Optional[float],
     typical_p: Optional[float],
-    no_repeat_ngram_size: Optional[int],
 ) -> StaticWarper:
     return StaticWarper(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         typical_p=typical_p,
-        no_repeat_ngram_size=no_repeat_ngram_size,
     )
 
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index f4b1e746..3766041e 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -18,7 +18,11 @@ from text_generation_server.utils.logits_process import (
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers import (
+    PreTrainedTokenizerBase,
+    RepetitionPenaltyLogitsProcessor,
+    NoRepeatNGramLogitsProcessor,
+)
 
 
 class NextTokenChooser:
@@ -58,6 +62,12 @@ class NextTokenChooser:
             if grammar != ""
             else None
         )
+
+        self.no_repeat_ngram_processor = (
+            NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            if no_repeat_ngram_size and no_repeat_ngram_size > 0
+            else None
+        )
         self.tokenizer = tokenizer
 
         has_warpers = (
@@ -65,15 +75,10 @@ class NextTokenChooser:
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
-            or (no_repeat_ngram_size is not None and no_repeat_ngram_size > 0)
         )
         if has_warpers:
             self.static_warper = static_warper(
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                typical_p=typical_p,
-                no_repeat_ngram_size=no_repeat_ngram_size,
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
             )
         else:
             self.static_warper = None
@@ -93,6 +98,8 @@ class NextTokenChooser:
             scores = self.frequency_processor(input_ids, scores)
         if self.grammar_processor is not None:
             scores = self.grammar_processor(scores, self.fsm_grammar_state)
+        if self.no_repeat_ngram_processor is not None:
+            scores = self.no_repeat_ngram_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -246,6 +253,7 @@ class HeterogeneousNextTokenChooser:
         temperature: List[float],
         repetition_penalty: List[float],
         frequency_penalty: List[float],
+        no_repeat_ngram_size: List[int],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
@@ -294,6 +302,18 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        self.no_repeat_ngram_processor = (
+            HeterogeneousProcessorWrapper(
+                {
+                    i: NoRepeatNGramLogitsProcessor(n)
+                    for i, n in enumerate(no_repeat_ngram_size)
+                    if n > 0
+                }
+            )
+            if any([n > 0 for n in no_repeat_ngram_size])
+            else None
+        )
+
         if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -360,6 +380,8 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.frequency_processor(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            if self.no_repeat_ngram_processor is not None:
+                _scores = self.no_repeat_ngram_processor(input_ids, _scores)
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
@@ -494,6 +516,7 @@ class HeterogeneousNextTokenChooser:
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
             frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
+            no_repeat_ngram_size=[pb_.no_repeat_ngram_size for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],
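For context, here is a minimal standalone sketch of the processor this diff plugs into NextTokenChooser and HeterogeneousNextTokenChooser. It is not part of the patch; the toy token IDs and vocabulary size are invented for illustration, but the NoRepeatNGramLogitsProcessor API is the real one from transformers, called with (input_ids, scores) exactly as in the __call__ paths above.

# Sketch only: shows the effect of NoRepeatNGramLogitsProcessor in isolation.
# Tokens that would complete an n-gram already present in input_ids have their
# logits set to -inf, so they can never be sampled again in that context.
import torch
from transformers import NoRepeatNGramLogitsProcessor

processor = NoRepeatNGramLogitsProcessor(2)  # ban repeated bigrams

# Toy sequence "7 8 9 7" over a 10-token vocabulary: the bigram (7, 8) has
# already occurred, so after the trailing 7 the processor must forbid token 8.
input_ids = torch.tensor([[7, 8, 9, 7]])
scores = torch.zeros(1, 10)

filtered = processor(input_ids, scores)
print(filtered[0, 8])  # -inf: generating 8 would repeat the bigram (7, 8)
print(filtered[0, 9])  # 0.0: all other tokens are left untouched

Because the processor only masks logits, it composes cleanly with the other processors in the chain (repetition, frequency, grammar) and with the static warpers, which is why the diff can drop the no_repeat_ngram_size handling from StaticWarper and apply it as a separate step before warping.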