Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
fix multinomial gpu cpu sync
This commit is contained in:
parent
1df2aa03c5
commit
3248fdfbd4
@@ -25,8 +25,11 @@ class Sampling:

     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, -1)
-        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
-        return next_tokens
+        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+        torch.div(probs, q, out=q)
+
+        return torch.argmax(q, dim=-1, keepdim=True)


 class Greedy:
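For context, the replacement relies on the exponential-race (Gumbel-max) formulation of categorical sampling: if q_i are i.i.d. Exponential(1) samples, then argmax_i(p_i / q_i) is distributed as Categorical(p), so the sampled token can be computed entirely on the device instead of going through torch.multinomial. Below is a minimal standalone sketch of the idea, not part of the commit; the toy distribution and variable names are illustrative only.

import torch

# Toy categorical distribution (illustrative only).
probs = torch.tensor([0.1, 0.2, 0.7])
generator = torch.Generator().manual_seed(0)

counts = torch.zeros(3)
for _ in range(10_000):
    # Exponential "race": divide the probabilities by i.i.d. Exp(1) noise
    # and take the argmax of the result.
    q = torch.empty_like(probs).exponential_(1, generator=generator)
    token = torch.argmax(probs / q, dim=-1)
    counts[token] += 1

# Empirical frequencies approach [0.1, 0.2, 0.7], matching torch.multinomial.
print(counts / counts.sum())

In the commit itself, keepdim=True keeps the output shaped like the torch.multinomial(..., num_samples=1) result it replaces, so callers see the same (batch, 1) tensor as before.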