From d969151a1e18ba82c7e7aa859ae4a1b35dd7bf1c Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 19 Apr 2024 03:23:30 +0000
Subject: [PATCH] fix: avoid frequency and repetition penalty on padding tokens

---
 server/text_generation_server/utils/logits_process.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 6d8cb71a..214cb32b 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -114,6 +114,8 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
             score < 0, score * self.penalty_tensor, score / self.penalty_tensor
         )
 
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)
         scores.scatter_(1, input_ids, score)
         return scores
 
@@ -143,6 +145,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
         score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)
 
         return scores.scatter_add_(1, input_ids, score)
 
@@ -168,6 +172,8 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = -torch.where(
             score < 0, score * self.penalty_tensor, score / self.penalty_tensor
         )
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)
 
        return scores.scatter_add_(1, input_ids, score)
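
Below is a minimal, self-contained sketch of the masking step in isolation; it is not part of the patch. It assumes the tokenizer pads with token id 0, as the patched processors do, and the batch contents, vocab size, and penalty value are illustrative only.

    import torch

    penalty = 1.2  # illustrative penalty value
    vocab_size = 5

    # batch of 2 sequences; the first is left-padded with token id 0
    input_ids = torch.tensor([[0, 0, 3], [1, 2, 4]])
    scores = torch.tensor(
        [
            [0.5, -0.2, 0.1, 0.8, -0.4],
            [0.3, 0.6, -0.7, 0.2, 0.9],
        ]
    )

    # gather one logit per previously seen token (same shape as input_ids)
    score = torch.gather(scores, 1, input_ids)
    score = torch.where(score < 0, score * penalty, score / penalty)

    # without this mask the pad token (id 0) is penalized as if it had
    # actually been generated, skewing the logit for token 0 in padded rows
    score *= input_ids.ne(0)

    scores.scatter_(1, input_ids, score)
    print(scores)

One behavioral note on the two write paths: in the repetition-penalty processors the masked score is written back with scatter_, so the logit at token id 0 is overwritten with 0.0 in padded rows rather than left at its original value; the frequency-penalty processors use scatter_add_, where adding the masked 0 leaves the pad-token logit unchanged.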