Fix dtype mismatch in HeterogeneousFrequencyPenaltyLogitsProcessor (#163)

This commit is contained in:
Karol Damaszke 2024-07-03 10:57:41 +02:00 committed by GitHub
parent 30342ca82d
commit 4b4382c6f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -159,9 +159,11 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
 vocab_size = scores.size(1)
 # Calculate the frequency for each token so far
-token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+token_freq = torch.zeros(
+batch_size, vocab_size, dtype=scores.dtype, device=scores.device
+)
 token_freq.scatter_add_(
-1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
 )
 token_freq /= input_size