diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 9f6f6c1e..af00f5a3 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -545,7 +545,7 @@ class MLPSpeculatorModel(torch.nn.Module): # h indicates # of generated tokens state = hidden_states b = state.size(0) - ind = input_ids[-b:].unsqueeze(0) + ind = input_ids.unsqueeze(0) out = torch.empty(1, b, self.n_predict, device=state.device).int() # b k h # log_probs = torch.zeros(1, b, device=state.device) # b k all_probs = torch.empty(