diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index db3264e7..6f6e984d 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -823,7 +823,7 @@ class CausalLM(Model): if token_idx is None: batch.input_ids[:, 0] = next_token_ids[:, 0] else: - batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1)) + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) # Slice unused values from prefill, use it to store next token if token_idx is None: diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 2535f464..4eef456a 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -214,8 +214,8 @@ class HeterogeneousNextTokenChooser: next_ids = self.choice(scores) # ignore logprobs if we use greedy search if type(self.choice) == Greedy: - logprobs = torch.zeros_like(scores, device="cpu") - next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu") + logprobs = torch.empty_like(scores, device="cpu") + next_logprobs = torch.empty_like(next_ids.view(-1), device="cpu") else: logprobs = torch.log_softmax(scores, -1) next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)