mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
optimize argmax
This commit is contained in:
parent
c59fb353a0
commit
a62f14872e
@@ -296,7 +296,8 @@ class HeterogeneousSampling:
|
|||||||
def __call__(self, logits):
|
def __call__(self, logits):
|
||||||
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||||
if self.greedy_indices:
|
if self.greedy_indices:
|
||||||
out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
|
# Computing for all indices is faster than slicing
|
||||||
|
torch.argmax(logits, -1, out=out)
|
||||||
|
|
||||||
for i, sampling in self.sampling_mapping.items():
|
for i, sampling in self.sampling_mapping.items():
|
||||||
out[i] = sampling(logits[i])
|
out[i] = sampling(logits[i])
|
||||||
|
Loading…
Reference in New Issue
Block a user