From a62f14872e8846f0d0b7ea99511b45b6dfa05b98 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 24 May 2023 16:28:16 +0200 Subject: [PATCH] optimize argmax --- server/text_generation_server/utils/tokens.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 3c3bcb68..371a9ee4 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -296,7 +296,8 @@ class HeterogeneousSampling: def __call__(self, logits): out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) if self.greedy_indices: - out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1) + # Computing for all indices is faster than slicing + torch.argmax(logits, -1, out=out) for i, sampling in self.sampling_mapping.items(): out[i] = sampling(logits[i])