mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Fix the issue of out of range (#98)
Signed-off-by: yuanwu <yuan.wu@intel.com> Co-authored-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
602a920ec5
commit
7149ac30e6
@ -305,7 +305,7 @@ class HeterogeneousSampling:
|
||||
self.greedy = Greedy()
|
||||
|
||||
def __call__(self, logits):
|
||||
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||
out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||
if self.greedy_indices:
|
||||
# Computing for all indices is faster than slicing
|
||||
torch.argmax(logits, -1, out=out)
|
||||
|
Loading…
Reference in New Issue
Block a user