chore: rebase and fix formatting

This commit is contained in:
martini 2024-04-30 09:46:27 +02:00 committed by Nicolas Patry
parent fcbd7fcd2e
commit 21ec5393ac
2 changed files with 5 additions and 2 deletions

2
.gitignore vendored
View File

@@ -11,3 +11,5 @@ server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
data/

View File

@@ -146,7 +146,6 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score)
@@ -172,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
# Calculate the frequency for each token so far
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
-token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
+token_freq.scatter_add_(
+    1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+)
token_freq /= input_size
# Apply the frequency penalty to logits