This commit is contained in:
Joel Lamy-Poirier 2023-05-03 15:17:06 -04:00
parent 4554a69b22
commit d5685656a4
No known key found for this signature in database
GPG Key ID: 82EE2141E842DFCF

View File

@ -216,9 +216,13 @@ class VectorizedNextTokenChooser:
else:
self.typical_p=None
self.do_sample = any(do_sample)
if self.do_sample and not all(do_sample):
raise NotImplementedError("Mixed greedy and probabilistic sampling not supported")
num_do_sample=sum(do_sample)
self.do_sample = num_do_sample>0
if self.do_sample and num_do_sample<self.batch_size:
# Mixed greedy and probabilistic sampling. Compute both and pick the right one.
self.do_sample_v=torch.tensor(do_sample, dtype=torch.bool, device=device)
else:
self.do_sample_v=None
def _standardize(self, values, default):
if isinstance(values, list):
@ -289,7 +293,10 @@ class VectorizedNextTokenChooser:
logprobs = torch.log_softmax(scores, dim=-1)
if self.do_sample:
raise NotImplementedError()
probs = torch.nn.functional.softmax(scores, -1)
next_token_ids = torch.multinomial(probs, num_samples=1)
if self.do_sample_v is not None:
next_token_ids=torch.where(self.do_sample_v, next_token_ids,torch.argmax(scores, dim=-1))
else:
next_token_ids = torch.argmax(scores, dim=-1)