mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix logits processor
This commit is contained in:
parent
75b492d720
commit
cfacf91af8
@ -136,10 +136,10 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
||||
score = -torch.where(
|
||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
||||
score < 0, score * self.penalty, score / self.penalty
|
||||
)
|
||||
|
||||
return scores - torch.scatter(scores, 1, input_ids, score)
|
||||
return scores.scatter_add_(1, input_ids, score)
|
||||
|
||||
|
||||
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
|
Loading…
Reference in New Issue
Block a user