diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index a3515aa1..9f6f6c1e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -547,7 +547,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         b = state.size(0)
         ind = input_ids[-b:].unsqueeze(0)
         out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
-        log_probs = torch.zeros(1, b, device=state.device)  # b k
+        # log_probs = torch.zeros(1, b, device=state.device)  # b k
         all_probs = torch.empty(
             1, b, self.n_predict, self.vsize, device=state.device
         )  # b k h v
@@ -561,14 +561,14 @@ class MLPSpeculatorModel(torch.nn.Module):
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
+            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
             out[:, :, i : i + 1] = preds
 
             # Update distribution set with new logits
-            all_probs[:, :, i] = _probs.exp()
+            all_probs[:, :, i] = probs.exp()
 
             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
@@ -576,21 +576,21 @@ class MLPSpeculatorModel(torch.nn.Module):
             )  # b k k' d
             state = state.reshape(-1, b, state.size(3))  # b kk' d
             ind = preds.view(-1, b)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(
-                -1, b, top_k_tokens_per_head[i]
-            )  # b k k'
-            log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
+            # log_probs = log_probs.unsqueeze(2).expand(
+            #     -1, b, top_k_tokens_per_head[i]
+            # )  # b k k'
+            # log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
 
         # print("done")
         # Take only top n best guesses
-        best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
+        # best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
         # speculative_logits = all_probs.gather(
         #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
         # ).squeeze(0)
         speculative_logits = all_probs[0]
         # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
         # TODO Why is this shift existing, are speculative logits also including the natural next token ?
-        return speculative_logits[:, 1:]
+        return speculative_logits
 
 
 class MLPSpeculatorHead(nn.Module):
@@ -607,6 +607,7 @@ class MLPSpeculatorHead(nn.Module):
         if input.shape[0] > 128:
             return logits, None
 
+        input_ids = logits.argmax(dim=-1)
         speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
 
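
For reviewers, here is a minimal, self-contained sketch of the shape bookkeeping the patched `MLPSpeculatorModel.forward` performs. Everything in it is illustrative rather than the real implementation: the weights are random stand-ins, `F.gelu` stands in for `self.activation`, the `emb_weight`/`state_weight` scaling constants are dropped, and `top_k_tokens_per_head = [1, 1, 1]` is an assumed configuration in which every head keeps a single greedy candidate, so the commented-out `log_probs`/`best_guesses` tree pruning would have nothing to prune.

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical sizes; the real values come from the speculator checkpoint.
b, n_predict, vsize, inner_dim = 2, 3, 50, 8
top_k_tokens_per_head = [1] * n_predict  # one candidate per head (assumed)

# Random stand-ins for self.emb / self.proj / self.ln / self.head.
emb = torch.nn.Embedding(vsize, inner_dim)
proj = [torch.nn.Linear(inner_dim, inner_dim) for _ in range(n_predict)]
ln = [torch.nn.LayerNorm(inner_dim) for _ in range(n_predict)]
head = [torch.nn.Linear(inner_dim, vsize) for _ in range(n_predict)]

state = torch.randn(1, b, inner_dim)  # b k d, with k == 1 candidate chain
ind = torch.randint(vsize, (1, b))    # greedy next tokens from the base model
all_probs = torch.empty(1, b, n_predict, vsize)  # b k h v

for i in range(n_predict):
    z = emb(ind) * math.sqrt(inner_dim / 2)        # b k d (emb_weight omitted)
    state = F.gelu(ln[i](proj[i](state) + z))      # b k d (state_weight omitted)
    probs = F.log_softmax(head[i](state), dim=-1)  # b k v
    _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
    all_probs[:, :, i] = probs.exp()  # keep the full distribution per head
    # With k' == 1 these reshapes preserve a single chain of candidates.
    state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)
    state = state.reshape(-1, b, state.size(3))    # b kk' d
    ind = preds.view(-1, b)                        # b kk'

speculative_logits = all_probs[0]  # [b, n_predict, vsize], no [:, 1:] shift
assert speculative_logits.shape == (b, n_predict, vsize)
```

On the dropped `[:, 1:]` shift: before this patch the speculator was seeded with the sequence's own last token, so head 0 re-predicted the same position as the base LM head (the "natural next token" the TODO asks about) and had to be sliced off. With the new `input_ids = logits.argmax(dim=-1)` in `MLPSpeculatorHead.forward`, the speculator appears to be seeded with the base model's predicted next token instead, so head 0 already targets a genuinely new position and all `n_predict` distributions can be returned.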