This commit is contained in:
Nicolas Patry 2024-05-08 06:33:13 +00:00
parent 1fde6850bb
commit 1a8a18d541

View File

@ -545,7 +545,7 @@ class MLPSpeculatorModel(torch.nn.Module):
# h indicates # of generated tokens
state = hidden_states
b = state.size(0)
ind = input_ids[-b:].unsqueeze(0)
ind = input_ids.unsqueeze(0)
out = torch.empty(1, b, self.n_predict, device=state.device).int() # b k h
# log_probs = torch.zeros(1, b, device=state.device) # b k
all_probs = torch.empty(