fix: avoid frequency and repetition penalty on padding tokens

This commit is contained in:
drbh 2024-04-19 03:23:30 +00:00
parent ed72e92126
commit d969151a1e

View File

@ -114,6 +114,8 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
scores.scatter_(1, input_ids, score)
return scores
@ -143,6 +145,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score)
@ -168,6 +172,8 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score)