mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
Fix dtype mismatch in HeterogeneousFrequencyPenaltyLogitsProcessor (#163)
This commit is contained in:
parent
30342ca82d
commit
4b4382c6f8
@ -159,9 +159,11 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
vocab_size = scores.size(1)
|
||||
|
||||
# Calculate the frequency for each token so far
|
||||
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
|
||||
token_freq = torch.zeros(
|
||||
batch_size, vocab_size, dtype=scores.dtype, device=scores.device
|
||||
)
|
||||
token_freq.scatter_add_(
|
||||
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
|
||||
1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
|
||||
)
|
||||
token_freq /= input_size
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user