mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: avoid frequency and repetition penalty on padding tokens
This commit is contained in:
parent
ed72e92126
commit
d969151a1e
@ -114,6 +114,8 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# set score to 0 where input_ids is a padding token
|
||||||
|
score *= input_ids.ne(0)
|
||||||
scores.scatter_(1, input_ids, score)
|
scores.scatter_(1, input_ids, score)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@ -143,6 +145,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
score = torch.gather(scores, 1, input_ids)
|
score = torch.gather(scores, 1, input_ids)
|
||||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
||||||
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
|
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
|
# set score to 0 where input_ids is a padding token
|
||||||
|
score *= input_ids.ne(0)
|
||||||
|
|
||||||
return scores.scatter_add_(1, input_ids, score)
|
return scores.scatter_add_(1, input_ids, score)
|
||||||
|
|
||||||
@ -168,6 +172,8 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
score = -torch.where(
|
score = -torch.where(
|
||||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
||||||
)
|
)
|
||||||
|
# set score to 0 where input_ids is a padding token
|
||||||
|
score *= input_ids.ne(0)
|
||||||
|
|
||||||
return scores.scatter_add_(1, input_ids, score)
|
return scores.scatter_add_(1, input_ids, score)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user