From 3248fdfbd401eae70e106f24e8f68b535e8f7e36 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 10 May 2023 17:54:04 +0200 Subject: [PATCH] fix multinomial gpu cpu sync --- server/text_generation_server/utils/tokens.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 045f7100..789e2e85 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -25,8 +25,11 @@ class Sampling: def __call__(self, logits): probs = torch.nn.functional.softmax(logits, -1) - next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator) - return next_tokens + # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 + q = torch.empty_like(probs).exponential_(1, generator=self.generator) + torch.div(probs, q, out=q) + + return torch.argmax(q, dim=-1, keepdim=True) class Greedy: