This commit is contained in:
OlivierDehaene 2023-05-10 18:40:26 +02:00
parent a944dd0fd5
commit b1f80702ef

View File

@ -26,8 +26,7 @@ class Sampling:
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
torch.div(probs, q, out=q)
q = torch.empty_like(probs).exponential_(1, generator=self.generator).div_(probs)
return q.argmax()