mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
This commit is contained in:
parent
03c2123244
commit
a5c788cfe4
@@ -823,7 +823,7 @@ class CausalLM(Model):
|
|||||||
if token_idx is None:
|
if token_idx is None:
|
||||||
batch.input_ids[:, 0] = next_token_ids[:, 0]
|
batch.input_ids[:, 0] = next_token_ids[:, 0]
|
||||||
else:
|
else:
|
||||||
batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
|
batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
|
||||||
|
|
||||||
# Slice unused values from prefill, use it to store next token
|
# Slice unused values from prefill, use it to store next token
|
||||||
if token_idx is None:
|
if token_idx is None:
|
||||||
|
@@ -214,8 +214,8 @@ class HeterogeneousNextTokenChooser:
|
|||||||
next_ids = self.choice(scores)
|
next_ids = self.choice(scores)
|
||||||
# ignore logprobs if we use greedy search
|
# ignore logprobs if we use greedy search
|
||||||
if type(self.choice) == Greedy:
|
if type(self.choice) == Greedy:
|
||||||
logprobs = torch.zeros_like(scores, device="cpu")
|
logprobs = torch.empty_like(scores, device="cpu")
|
||||||
next_logprobs = torch.zeros_like(next_ids.view(-1), device="cpu")
|
next_logprobs = torch.empty_like(next_ids.view(-1), device="cpu")
|
||||||
else:
|
else:
|
||||||
logprobs = torch.log_softmax(scores, -1)
|
logprobs = torch.log_softmax(scores, -1)
|
||||||
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||||
|
Loading…
Reference in New Issue
Block a user