Fix dtype mismatch in HeterogeneousFrequencyPenaltyLogitsProcessor (#163)

Author: Karol Damaszke  2024-07-03 10:57:41 +02:00 (committed by GitHub)
parent 30342ca82d
commit 4b4382c6f8


@@ -159,9 +159,11 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         vocab_size = scores.size(1)
         # Calculate the frequency for each token so far
-        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq = torch.zeros(
+            batch_size, vocab_size, dtype=scores.dtype, device=scores.device
+        )
         token_freq.scatter_add_(
-            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+            1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
         )
         token_freq /= input_size
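
The change makes the token-frequency tensor inherit the dtype and device of scores instead of hard-coding torch.float and input_ids.device, so the scatter_add_ source and the later arithmetic against scores stay in the same precision (relevant when scores is float16/bfloat16). Below is a minimal, self-contained sketch of that pattern, not the repository's actual class: the function name apply_frequency_penalty, the scalar penalty (the real processor uses a per-request penalty tensor), and the toy shapes are assumptions for illustration.

import torch

def apply_frequency_penalty(input_ids: torch.Tensor,
                            scores: torch.Tensor,
                            penalty: float) -> torch.Tensor:
    batch_size, input_size = input_ids.shape
    vocab_size = scores.size(1)

    # Allocate the frequency table with the same dtype/device as `scores`.
    # If `scores` is half precision, a float32 tensor here would make the
    # scatter_add_ source and the final subtraction mismatch.
    token_freq = torch.zeros(
        batch_size, vocab_size, dtype=scores.dtype, device=scores.device
    )
    # Count how often each vocabulary id appears in the generated sequence.
    token_freq.scatter_add_(
        1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
    )
    token_freq /= input_size

    # Penalize tokens proportionally to how often they already appeared.
    return scores - penalty * token_freq

# Toy usage: bfloat16 scores, as a half-precision model might produce.
scores = torch.randn(2, 10, dtype=torch.bfloat16)
input_ids = torch.tensor([[1, 2, 2], [3, 3, 3]])
penalized = apply_frequency_penalty(input_ids, scores, penalty=0.5)
print(penalized.dtype)  # bfloat16, same dtype as the input scores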