optimize argmax

This commit is contained in:
OlivierDehaene 2023-05-24 16:28:16 +02:00
parent c59fb353a0
commit a62f14872e

View File

@ -296,7 +296,8 @@ class HeterogeneousSampling:
def __call__(self, logits): def __call__(self, logits):
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices: if self.greedy_indices:
out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1) # Computing for all indices is faster than slicing
torch.argmax(logits, -1, out=out)
for i, sampling in self.sampling_mapping.items(): for i, sampling in self.sampling_mapping.items():
out[i] = sampling(logits[i]) out[i] = sampling(logits[i])