diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 4eef456a..c4775a09 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -305,7 +305,7 @@ class HeterogeneousSampling: self.greedy = Greedy() def __call__(self, logits): - out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) + out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device) if self.greedy_indices: # Computing for all indices is faster than slicing torch.argmax(logits, -1, out=out)