Fix the issue of out of range (#98)

Signed-off-by: yuanwu <yuan.wu@intel.com>
Co-authored-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
Yao Matrix 2024-03-13 17:09:53 +08:00 committed by GitHub
parent 602a920ec5
commit 7149ac30e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -305,7 +305,7 @@ class HeterogeneousSampling:
self.greedy = Greedy()
def __call__(self, logits):
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices:
# Computing for all indices is faster than slicing
torch.argmax(logits, -1, out=out)