This commit is contained in:
OlivierDehaene 2023-05-10 18:20:28 +02:00
parent 3248fdfbd4
commit a944dd0fd5

View File

@ -29,7 +29,7 @@ class Sampling:
q = torch.empty_like(probs).exponential_(1, generator=self.generator) q = torch.empty_like(probs).exponential_(1, generator=self.generator)
torch.div(probs, q, out=q) torch.div(probs, q, out=q)
return torch.argmax(q, dim=-1, keepdim=True) return q.argmax()
class Greedy: class Greedy:
@ -107,36 +107,36 @@ class NextTokenChooser:
seed=0, seed=0,
device="cpu", device="cpu",
): ):
self.watermark_warper = ( self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None WatermarkLogitsProcessor(device=device) if watermark else None
) )
self.repetition_warper = ( self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty if repetition_penalty
else None else None
) )
sampling = ( has_warpers = (
do_sample (temperature is not None and temperature != 1.0)
or (temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0) or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0) or (top_p is not None and top_p < 1.0)
or (typical_p is not None and typical_p < 1.0) or (typical_p is not None and typical_p < 1.0)
) )
if sampling: if has_warpers:
self.choice = Sampling(seed, device)
self.static_warper = static_warper( self.static_warper = static_warper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
) )
else: else:
self.choice = Greedy()
self.static_warper = None self.static_warper = None
sampling = do_sample or has_warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores): def __call__(self, input_ids, scores):
if self.watermark_warper: if self.watermark_processor:
scores = self.watermark_warper(input_ids, scores) scores = self.watermark_processor(input_ids, scores)
if self.repetition_warper: if self.repetition_processor:
scores = self.repetition_warper(input_ids, scores) scores = self.repetition_processor(input_ids, scores)
if self.static_warper is None: if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1) next_logprob = torch.log_softmax(scores, -1)