Mirror of https://github.com/huggingface/text-generation-inference.git

commit a5c788cfe4 (parent 03c2123244)
Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
@@ -823,7 +823,7 @@ class CausalLM(Model):
             if token_idx is None:
                 batch.input_ids[:, 0] = next_token_ids[:, 0]
             else:
-                batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
+                batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))

             # Slice unused values from prefill, use it to store next token
             if token_idx is None:
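This hunk removes the .cpu() call on token_idx, so the index tensor stays on the same device as batch.input_ids rather than being moved to the host on every decode step. A minimal sketch of the index_copy_ pattern, run on CPU for simplicity (the shapes and values are illustrative assumptions, not taken from the commit):

import torch

# input_ids is a [batch, seq_len] buffer of generated tokens.
batch_size, seq_len = 2, 8
input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)

# token_idx marks the column that should receive the freshly sampled tokens;
# next_token_ids holds one new token per sequence in the batch.
token_idx = torch.tensor([3])
next_token_ids = torch.tensor([11, 22])

# index_copy_ along dim 1 writes next_token_ids into column 3 of every row.
# On an accelerator, both the destination and the index would live on the
# device, which is presumably why the diff no longer moves token_idx to CPU.
input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
print(input_ids)
# tensor([[ 0,  0,  0, 11,  0,  0,  0,  0],
#         [ 0,  0,  0, 22,  0,  0,  0,  0]])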
@@ -214,8 +214,8 @@ class HeterogeneousNextTokenChooser:
         next_ids = self.choice(scores)
         # ignore logprobs if we use greedy search
         if type(self.choice) == Greedy:
-            logprobs = torch.zeros_like(scores, device="cpu")
-            next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
+            logprobs = torch.empty_like(scores, device="cpu")
+            next_logprobs = torch.empty_like(next_ids.view(-1), device="cpu")
         else:
             logprobs = torch.log_softmax(scores, -1)
             next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
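This hunk swaps torch.zeros_like for torch.empty_like in the greedy branch. As the existing comment notes, greedy search ignores the logprobs, so the buffers only need the right shape and dtype; empty_like allocates memory without the zero-fill that zeros_like performs. A minimal sketch contrasting the two branches (shapes are assumptions, and argmax stands in for the Greedy chooser):

import torch

# scores is a [batch, vocab_size] tensor of logits; next_ids holds the
# chosen token id per sequence.
scores = torch.randn(4, 32000)
next_ids = scores.argmax(-1)

# Greedy branch: placeholder buffers whose contents are never read, so
# allocation alone (empty_like) is enough.
logprobs = torch.empty_like(scores, device="cpu")
next_logprobs = torch.empty_like(next_ids.view(-1), device="cpu")

# Sampling branch: the values are consumed downstream and must be computed.
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)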