From 21ec5393ac7ece9333aeeb8f9f9e1656e4318150 Mon Sep 17 00:00:00 2001
From: martini
Date: Tue, 30 Apr 2024 09:46:27 +0200
Subject: [PATCH] chore: rebase and fix formatting

---
 .gitignore                                             | 2 ++
 server/text_generation_server/utils/logits_process.py | 5 +++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index b3ca772b..2ac2f6b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,5 @@ server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
 server/exllama_kernels/exllama_kernels/hip_buffers.cuh
 server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
+
+data/
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 8cf85c9f..6b915437 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -146,7 +146,6 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
 
         # set score to 0 where input_ids is a padding token
         score *= input_ids.ne(0)
-
         return scores.scatter_add_(1, input_ids, score)
 
 
@@ -172,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
 
         # Calculate the frequency for each token so far
         token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
-        token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
+        token_freq.scatter_add_(
+            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+        )
         token_freq /= input_size
 
         # Apply the frequency penalty to logits
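
Note (not part of the patch): the reformatted hunk builds a per-sequence frequency penalty by counting, for each batch row, how often each vocab id already appears in input_ids, then reducing those logits. Below is a minimal standalone Python sketch of that computation. Tensor names mirror the patched code; the scalar penalty and the final subtraction are illustrative assumptions, since the actual penalty application lies outside the visible hunk.

import torch

batch_size, input_size, vocab_size = 2, 5, 10
penalty = 0.5  # assumption: scalar stand-in for the processor's penalty

input_ids = torch.randint(0, vocab_size, (batch_size, input_size))
scores = torch.randn(batch_size, vocab_size)

# Count how often each vocab id occurs in each row of input_ids.
token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(
    1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
)
token_freq /= input_size

# Apply the frequency penalty to the logits (illustrative; the hunk only
# shows the comment marking this step, not the subtraction itself).
scores -= token_freq * penalty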