add missed commit

Nathan Brake 2024-07-23 12:44:50 -04:00 committed by erikkaum
parent d8d3c4678e
commit 54b45be38d
2 changed files with 33 additions and 17 deletions


@@ -20,18 +20,13 @@ from transformers import (
    TypicalLogitsWarper,
)
from transformers.generation.logits_process import _calc_banned_ngram_tokens

mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None


class StaticWarper:
    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        no_repeat_ngram_size=None,
    ):
    def __init__(self, temperature=1.0, top_k=None, top_p=None, typical_p=None):
        self.warpers = []

        if temperature is not None and temperature != 1.0:
@@ -83,14 +78,12 @@ def static_warper(
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
    no_repeat_ngram_size: Optional[int],
) -> StaticWarper:
    return StaticWarper(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        typical_p=typical_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
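The parameter dropped from the static warper above is handled instead by transformers' NoRepeatNGramLogitsProcessor, which the second changed file wires in below. A minimal sketch of what that processor does; the toy vocabulary size and token ids are made up for illustration:

import torch
from transformers import NoRepeatNGramLogitsProcessor

# Toy example: vocabulary of 5 tokens, batch of one sequence.
# The sequence already contains the bigram (3, 4) and currently ends in 3,
# so with ngram_size=2 the processor bans token 4 for the next step.
processor = NoRepeatNGramLogitsProcessor(2)
input_ids = torch.tensor([[3, 4, 1, 3]])
scores = torch.zeros((1, 5))
masked = processor(input_ids, scores)
# masked[0, 4] is now -inf; all other scores are unchanged.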


@@ -18,7 +18,11 @@ from text_generation_server.utils.logits_process import (
    static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from transformers import (
    PreTrainedTokenizerBase,
    RepetitionPenaltyLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
)


class NextTokenChooser:
@@ -58,6 +62,12 @@ class NextTokenChooser:
            if grammar != ""
            else None
        )
        self.no_repeat_ngram_processor = (
            NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
            if no_repeat_ngram_size and no_repeat_ngram_size > 0
            else None
        )
        self.tokenizer = tokenizer

        has_warpers = (
@@ -65,15 +75,10 @@ class NextTokenChooser:
            or (top_k is not None and top_k != 0)
            or (top_p is not None and top_p < 1.0)
            or (typical_p is not None and typical_p < 1.0)
            or (no_repeat_ngram_size is not None and no_repeat_ngram_size > 0)
        )

        if has_warpers:
            self.static_warper = static_warper(
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                no_repeat_ngram_size=no_repeat_ngram_size,
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
        else:
            self.static_warper = None
@@ -93,6 +98,8 @@ class NextTokenChooser:
            scores = self.frequency_processor(input_ids, scores)
        if self.grammar_processor is not None:
            scores = self.grammar_processor(scores, self.fsm_grammar_state)
        if self.no_repeat_ngram_processor is not None:
            scores = self.no_repeat_ngram_processor(input_ids, scores)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
@@ -246,6 +253,7 @@ class HeterogeneousNextTokenChooser:
        temperature: List[float],
        repetition_penalty: List[float],
        frequency_penalty: List[float],
        no_repeat_ngram_size: List[int],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
@@ -294,6 +302,18 @@ class HeterogeneousNextTokenChooser:
            else None
        )
        self.no_repeat_ngram_processor = (
            HeterogeneousProcessorWrapper(
                {
                    i: NoRepeatNGramLogitsProcessor(n)
                    for i, n in enumerate(no_repeat_ngram_size)
                    if n > 0
                }
            )
            if any([n > 0 for n in no_repeat_ngram_size])
            else None
        )

        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -360,6 +380,8 @@ class HeterogeneousNextTokenChooser:
                _scores = self.frequency_processor(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            if self.no_repeat_ngram_processor is not None:
                _scores = self.no_repeat_ngram_processor(input_ids, _scores)

            for warper in self.warpers:
                _scores = warper(input_ids, _scores)

            _next_ids = self.choice(_scores)
@@ -494,6 +516,7 @@ class HeterogeneousNextTokenChooser:
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
            no_repeat_ngram_size=[pb_.no_repeat_ngram_size for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
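For the batched chooser, the dict comprehension above builds one NoRepeatNGramLogitsProcessor per request that set no_repeat_ngram_size > 0, and HeterogeneousProcessorWrapper applies each one to its own row of the batch. A rough sketch of that per-row pattern, with apply_per_row standing in as a hypothetical substitute for the wrapper (not its actual implementation):

import torch
from transformers import NoRepeatNGramLogitsProcessor

def apply_per_row(processors, input_ids, scores):
    # Hypothetical stand-in for HeterogeneousProcessorWrapper: each processor
    # only sees and rewrites the row of the request that configured it.
    for i, processor in processors.items():
        scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
    return scores

no_repeat_ngram_size = [2, 0, 3]  # one value per request in the batch
processors = {
    i: NoRepeatNGramLogitsProcessor(n)
    for i, n in enumerate(no_repeat_ngram_size)
    if n > 0
}

input_ids = torch.tensor([[3, 4, 1, 3]] * 3)
scores = apply_per_row(processors, input_ids, torch.zeros((3, 5)))
# Row 0 (ngram_size 2) gets token 4 masked to -inf because the bigram (3, 4)
# already appeared; row 1 (size 0) is untouched; row 2 (size 3) bans nothing
# yet, since no 3-gram starting with its last two tokens has repeated.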