diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 3c3bcb68..371a9ee4 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -296,7 +296,8 @@ class HeterogeneousSampling: def __call__(self, logits): out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) if self.greedy_indices: - out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1) + # Computing for all indices is faster than slicing + torch.argmax(logits, -1, out=out) for i, sampling in self.sampling_mapping.items(): out[i] = sampling(logits[i])