From a5c788cfe48e27f86a0e7a6f9e909c00f1c060c3 Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Fri, 1 Mar 2024 01:32:02 +0100
Subject: [PATCH] Remove redundant fill op (#83) (#90)

Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
---
 server/text_generation_server/models/causal_lm.py | 2 +-
 server/text_generation_server/utils/tokens.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index db3264e7..6f6e984d 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -823,7 +823,7 @@ class CausalLM(Model):
         if token_idx is None:
             batch.input_ids[:, 0] = next_token_ids[:, 0]
         else:
-            batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
+            batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
 
         # Slice unused values from prefill, use it to store next token
         if token_idx is None:
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 2535f464..4eef456a 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -214,8 +214,8 @@ class HeterogeneousNextTokenChooser:
         next_ids = self.choice(scores)
         # ignore logprobs if we use greedy search
         if type(self.choice) == Greedy:
-            logprobs = torch.zeros_like(scores, device="cpu")
-            next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
+            logprobs = torch.empty_like(scores, device="cpu")
+            next_logprobs = torch.empty_like(next_ids.view(-1), device="cpu")
         else:
             logprobs = torch.log_softmax(scores, -1)
             next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
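
A note on why the fill is redundant: per the in-tree comment, the greedy
branch ignores the logprob tensors, so they are placeholders that are never
read. torch.zeros_like allocates a buffer and then runs a fill kernel;
torch.empty_like allocates the same buffer and skips the fill. The
causal_lm.py hunk similarly drops a token_idx.cpu() call, presumably so
index_copy_ can consume the index on its current device instead of forcing
a copy to the CPU. A minimal sketch of the zeros_like/empty_like
distinction (shapes and names here are made up for illustration):

    import torch

    scores = torch.randn(4, 32000)         # stand-in for a batch of logits
    zeroed = torch.zeros_like(scores)      # allocates, then fills with 0.0
    scratch = torch.empty_like(scores)     # allocates only; contents are
                                           # arbitrary, which is fine when
                                           # the values are never read

    # index_copy_ overwrites the columns selected by a 1-D index tensor:
    ids = torch.tensor([7])                # hypothetical token_idx
    buf = torch.zeros(4, 16, dtype=torch.long)
    buf.index_copy_(1, ids, torch.ones(4, 1, dtype=torch.long))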