Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit d5685656a4: sampling
Parent: 4554a69b22
@@ -216,9 +216,13 @@ class VectorizedNextTokenChooser:
         else:
             self.typical_p=None

-        self.do_sample = any(do_sample)
-        if self.do_sample and not all(do_sample):
-            raise NotImplementedError("Mixed greedy and probabilistic sampling not supported")
+        num_do_sample=sum(do_sample)
+        self.do_sample = num_do_sample>0
+        if self.do_sample and num_do_sample<self.batch_size:
+            # Mixed greedy and probabilistic sampling. Compute both and pick the right one.
+            self.do_sample_v=torch.tensor(do_sample, dtype=torch.bool, device=device)
+        else:
+            self.do_sample_v=None

     def _standardize(self, values, default):
         if isinstance(values, list):
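For context, a minimal standalone sketch of the constructor-side logic this hunk introduces; the helper name build_sampling_mask is hypothetical and not part of the repository, which does this work inside VectorizedNextTokenChooser.__init__:

import torch

# Instead of rejecting batches that mix greedy and probabilistic requests,
# count the sampling requests and keep a per-row boolean mask only when the
# batch is actually mixed.
def build_sampling_mask(do_sample, batch_size, device="cpu"):
    num_do_sample = sum(do_sample)
    any_sample = num_do_sample > 0
    if any_sample and num_do_sample < batch_size:
        # Mixed batch: remember which rows should be sampled.
        do_sample_v = torch.tensor(do_sample, dtype=torch.bool, device=device)
    else:
        # All-greedy or all-sampling batch: no per-row mask is needed.
        do_sample_v = None
    return any_sample, do_sample_v

# Example: four requests, two of which asked for sampling.
any_sample, mask = build_sampling_mask([True, False, True, False], batch_size=4)
print(any_sample)  # True
print(mask)        # tensor([ True, False,  True, False])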
@@ -289,7 +293,10 @@ class VectorizedNextTokenChooser:
         logprobs = torch.log_softmax(scores, dim=-1)

         if self.do_sample:
-            raise NotImplementedError()
+            probs = torch.nn.functional.softmax(scores, -1)
+            next_token_ids = torch.multinomial(probs, num_samples=1)
+            if self.do_sample_v is not None:
+                next_token_ids=torch.where(self.do_sample_v, next_token_ids,torch.argmax(scores, dim=-1))
         else:
             next_token_ids = torch.argmax(scores, dim=-1)

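And a hedged sketch of the decode-time selection the second hunk enables; choose_next_tokens is a hypothetical helper, and unlike the diff as written it squeezes the sampled ids to shape (batch,) so that torch.where broadcasts elementwise against the argmax result (an assumption about the intended shapes, since multinomial returns (batch, 1)):

import torch

def choose_next_tokens(scores, do_sample_v):
    # Sample a candidate token for every row, and compute the greedy choice.
    probs = torch.nn.functional.softmax(scores, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (batch,)
    greedy = torch.argmax(scores, dim=-1)                          # (batch,)
    if do_sample_v is None:
        # Homogeneous batch: every row is sampled.
        return sampled
    # Mixed batch: take the sampled id for sampling rows, argmax otherwise.
    return torch.where(do_sample_v, sampled, greedy)

scores = torch.randn(4, 32000)  # (batch, vocab) logits
mask = torch.tensor([True, False, True, False])
print(choose_next_tokens(scores, mask).shape)  # torch.Size([4])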