Remove redundant fill op (#83) (#90)

Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
This commit is contained in:
Karol Damaszke 2024-03-01 01:32:02 +01:00 committed by GitHub
parent 03c2123244
commit a5c788cfe4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@@ -823,7 +823,7 @@ class CausalLM(Model):
if token_idx is None:
batch.input_ids[:, 0] = next_token_ids[:, 0]
else:
batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
# Slice unused values from prefill, use it to store next token
if token_idx is None:

View File

@@ -214,8 +214,8 @@ class HeterogeneousNextTokenChooser:
next_ids = self.choice(scores)
# ignore logprobs if we use greedy search
if type(self.choice) == Greedy:
logprobs = torch.zeros_like(scores, device="cpu")
next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
logprobs = torch.empty_like(scores, device="cpu")
next_logprobs = torch.empty_like(next_ids.view(-1), device="cpu")
else:
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)