mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
chore: rebase and fix formatting
This commit is contained in:
parent
fcbd7fcd2e
commit
21ec5393ac
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,3 +11,5 @@ server/exllama_kernels/exllama_kernels/hip_func/
|
|||||||
*_hip.cuh
|
*_hip.cuh
|
||||||
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
|
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
|
||||||
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
|
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
|
||||||
|
|
||||||
|
data/
|
||||||
|
@ -146,7 +146,6 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
# set score to 0 where input_ids is a padding token
|
# set score to 0 where input_ids is a padding token
|
||||||
score *= input_ids.ne(0)
|
score *= input_ids.ne(0)
|
||||||
|
|
||||||
|
|
||||||
return scores.scatter_add_(1, input_ids, score)
|
return scores.scatter_add_(1, input_ids, score)
|
||||||
|
|
||||||
|
|
||||||
@ -172,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
# Calculate the frequency for each token so far
|
# Calculate the frequency for each token so far
|
||||||
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
|
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
|
||||||
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
|
token_freq.scatter_add_(
|
||||||
|
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
|
||||||
|
)
|
||||||
token_freq /= input_size
|
token_freq /= input_size
|
||||||
|
|
||||||
# Apply the frequency penalty to logits
|
# Apply the frequency penalty to logits
|
||||||
|
Loading…
Reference in New Issue
Block a user