From cfacf91af831d3b842541bb16fef6d0acc5ac2be Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:49:24 +0100 Subject: [PATCH] fix logits processor --- server/text_generation_server/utils/logits_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index b9a4b691..291c522f 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -136,10 +136,10 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor): score = torch.gather(scores, 1, input_ids) # if score < 0 then penalty has to be multiplied to reduce the previous token probability score = -torch.where( - score < 0, score * self.penalty_tensor, score / self.penalty_tensor + score < 0, score * self.penalty, score / self.penalty ) - return scores - torch.scatter(scores, 1, input_ids, score) + return scores.scatter_add_(1, input_ids, score) class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):