diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index ab02ba52..104fc2f0 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -159,9 +159,11 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): vocab_size = scores.size(1) # Calculate the frequency for each token so far - token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) + token_freq = torch.zeros( + batch_size, vocab_size, dtype=scores.dtype, device=scores.device + ) token_freq.scatter_add_( - 1, input_ids, torch.ones_like(input_ids, dtype=torch.float) + 1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device) ) token_freq /= input_size