fix logits processor

This commit is contained in:
OlivierDehaene 2024-02-08 12:49:24 +01:00
parent 75b492d720
commit cfacf91af8

View File

@@ -136,10 +136,10 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = torch.gather(scores, 1, input_ids) score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability # if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where( score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor score < 0, score * self.penalty, score / self.penalty
) )
return scores - torch.scatter(scores, 1, input_ids, score) return scores.scatter_add_(1, input_ids, score)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):