Mirror of https://github.com/huggingface/text-generation-inference.git
no-repeat-ngram is processor not warper
Commit d3fc28ebe7 (parent eb9e109b9c)
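Why the move matters: a no-repeat-ngram filter has to inspect the tokens generated so far to know which continuations to ban, so it behaves like a transformers LogitsProcessor (it depends on input_ids) rather than a pure warper that only reshapes the score distribution. TGI's StaticWarper path is built around score warping that can be captured in a CUDA graph with static score tensors (see the mempool / cuda_graph / static_scores lines below), which is presumably why the filter could not stay there. A minimal sketch using only the public transformers API, outside TGI; the vocabulary size and token ids are made up:

import torch
from transformers import NoRepeatNGramLogitsProcessor, TopKLogitsWarper

input_ids = torch.tensor([[1, 2, 3, 1, 2]])   # tokens generated so far
scores = torch.zeros(1, 100)                  # next-token logits over a toy vocab of 100

# Processor: needs input_ids to find which 3-gram continuations to ban.
# "1 2 3" was already produced, so token 3 is masked to -inf after "1 2".
scores = NoRepeatNGramLogitsProcessor(3)(input_ids, scores)
assert scores[0, 3].item() == float("-inf")

# Warper: only reshapes the distribution; the token history is irrelevant.
scores = TopKLogitsWarper(top_k=10)(input_ids, scores)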
@@ -18,21 +18,15 @@ from transformers import (
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
-    NoRepeatNGramLogitsProcessor
 )

+from transformers.generation.logits_process import _calc_banned_ngram_tokens
+
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None


 class StaticWarper:
-    def __init__(
-        self,
-        temperature=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        no_repeat_ngram_size=None,
-    ):
+    def __init__(self, temperature=1.0, top_k=None, top_p=None, typical_p=None):
        self.warpers = []

        if temperature is not None and temperature != 1.0:
@@ -44,8 +38,6 @@ class StaticWarper:
             self.warpers.append(TopPLogitsWarper(top_p=top_p))
         if typical_p is not None and typical_p < 1.0:
             self.warpers.append(TypicalLogitsWarper(mass=typical_p))
-        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
-            self.warpers.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))

         self.cuda_graph = None
         self.static_scores = None
@@ -86,10 +78,12 @@ def static_warper(
     top_k: Optional[int],
     top_p: Optional[float],
     typical_p: Optional[float],
-    no_repeat_ngram_size: Optional[int],
 ) -> StaticWarper:
     return StaticWarper(
-        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        typical_p=typical_p,
     )

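The import added at the top of this file, _calc_banned_ngram_tokens, is the private transformers helper that NoRepeatNGramLogitsProcessor itself uses to compute, per sequence, the tokens that would complete an already-seen n-gram; it is not used in the hunks shown here, so it is presumably groundwork for a custom implementation elsewhere. A rough sketch of how such a helper is typically applied; treat the exact signature as an assumption, since it is a private function that can change between transformers versions:

import torch
from transformers.generation.logits_process import _calc_banned_ngram_tokens

def ban_repeated_ngrams(input_ids: torch.Tensor, scores: torch.Tensor, ngram_size: int) -> torch.Tensor:
    # NOTE: signature assumed to be (ngram_size, prev_input_ids, num_hypos, cur_len);
    # private API, verify against your installed transformers version.
    batch_size, cur_len = input_ids.shape
    banned = _calc_banned_ngram_tokens(ngram_size, input_ids, batch_size, cur_len)
    for row, banned_tokens in enumerate(banned):
        # Mask every token that would recreate an n-gram already present in this row.
        scores[row, list(banned_tokens)] = -float("inf")
    return scores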
@@ -19,7 +19,11 @@ from text_generation_server.utils.logits_process import (
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers import (
+    PreTrainedTokenizerBase,
+    RepetitionPenaltyLogitsProcessor,
+    NoRepeatNGramLogitsProcessor,
+)


 class NextTokenChooser:
@@ -59,6 +63,12 @@ class NextTokenChooser:
             if grammar != ""
             else None
         )
+
+        self.no_repeat_ngram_processor = (
+            NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            if no_repeat_ngram_size and no_repeat_ngram_size > 0
+            else None
+        )
         self.tokenizer = tokenizer

         has_warpers = (
@@ -66,11 +76,10 @@ class NextTokenChooser:
             or (top_k is not None and top_k != 0)
             or (top_p is not None and top_p < 1.0)
             or (typical_p is not None and typical_p < 1.0)
-            or (no_repeat_ngram_size is not None and no_repeat_ngram_size > 0)
         )
         if has_warpers:
             self.static_warper = static_warper(
-                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p, no_repeat_ngram_size=no_repeat_ngram_size
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
             )
         else:
             self.static_warper = None
@@ -90,6 +99,8 @@ class NextTokenChooser:
             scores = self.frequency_processor(input_ids, scores)
         if self.grammar_processor is not None:
             scores = self.grammar_processor(scores, self.fsm_grammar_state)
+        if self.no_repeat_ngram_processor is not None:
+            scores = self.no_repeat_ngram_processor(input_ids, scores)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
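Here, and again in HeterogeneousNextTokenChooser further down, the new processor runs after the penalty and grammar processors and before the warpers / log-softmax, so a banned continuation is driven to -inf logits and therefore zero probability regardless of the sampling settings. A tiny self-contained check of that effect with a toy vocabulary (not TGI code):

import torch
from transformers import NoRepeatNGramLogitsProcessor

input_ids = torch.tensor([[5, 7, 5]])          # the bigram (5, 7) was already generated
scores = torch.zeros(1, 10)                    # uniform toy logits over 10 tokens

scores = NoRepeatNGramLogitsProcessor(2)(input_ids, scores)
probs = torch.softmax(scores, dim=-1)
assert probs[0, 7].item() == 0.0               # "7" can no longer follow "5"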
@@ -243,6 +254,7 @@ class HeterogeneousNextTokenChooser:
         temperature: List[float],
         repetition_penalty: List[float],
         frequency_penalty: List[float],
+        no_repeat_ngram_size: List[int],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
@@ -291,6 +303,18 @@ class HeterogeneousNextTokenChooser:
             else None
         )

+        self.no_repeat_ngram_processor = (
+            HeterogeneousProcessorWrapper(
+                {
+                    i: NoRepeatNGramLogitsProcessor(n)
+                    for i, n in enumerate(no_repeat_ngram_size)
+                    if n > 0
+                }
+            )
+            if any([n > 0 for n in no_repeat_ngram_size])
+            else None
+        )
+
         if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
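Judging from the dict comprehension above, HeterogeneousProcessorWrapper maps a batch index to a per-row processor, so only the requests that asked for no-repeat-ngram pay for it. As an illustration of that idea only (this is not TGI's implementation), applying a different no-repeat-ngram size per request can be done row by row with plain transformers objects:

import torch
from transformers import NoRepeatNGramLogitsProcessor

# Per-request n-gram sizes for a batch of three; 0 means the feature is disabled.
no_repeat_ngram_size = [2, 0, 3]
input_ids = torch.randint(0, 50, (3, 8))       # toy generated context
scores = torch.randn(3, 50)                    # toy next-token logits

processors = {
    i: NoRepeatNGramLogitsProcessor(n)
    for i, n in enumerate(no_repeat_ngram_size)
    if n > 0
}
for i, processor in processors.items():
    # Each processor only ever sees its own row of the batch.
    scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])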
@@ -357,6 +381,8 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.frequency_processor(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            if self.no_repeat_ngram_processor is not None:
+                _scores = self.no_repeat_ngram_processor(input_ids, _scores)
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
@@ -491,6 +517,7 @@ class HeterogeneousNextTokenChooser:
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
             frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
+            no_repeat_ngram_size=[pb_.no_repeat_ngram_size for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],