Fixed speculator.

Nicolas Patry 2024-05-08 06:31:40 +00:00
parent 9291d42865
commit 1fde6850bb


@@ -547,7 +547,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         b = state.size(0)
         ind = input_ids[-b:].unsqueeze(0)
         out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
-        log_probs = torch.zeros(1, b, device=state.device)  # b k
+        # log_probs = torch.zeros(1, b, device=state.device)  # b k
         all_probs = torch.empty(
             1, b, self.n_predict, self.vsize, device=state.device
         )  # b k h v
@@ -561,14 +561,14 @@ class MLPSpeculatorModel(torch.nn.Module):
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
+            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
             out[:, :, i : i + 1] = preds
 
             # Update distribution set with new logits
-            all_probs[:, :, i] = _probs.exp()
+            all_probs[:, :, i] = probs.exp()
 
             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
@@ -576,21 +576,21 @@ class MLPSpeculatorModel(torch.nn.Module):
             )  # b k k' d
             state = state.reshape(-1, b, state.size(3))  # b kk' d
             ind = preds.view(-1, b)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(
-                -1, b, top_k_tokens_per_head[i]
-            )  # b k k'
-            log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
+            # log_probs = log_probs.unsqueeze(2).expand(
+            #     -1, b, top_k_tokens_per_head[i]
+            # )  # b k k'
+            # log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
 
         # print("done")
         # Take only top n best guesses
-        best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
+        # best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
         # speculative_logits = all_probs.gather(
         #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
         # ).squeeze(0)
         speculative_logits = all_probs[0]
         # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
         # TODO Why is this shift existing, are speculative logits also including the natural next token ?
-        return speculative_logits[:, 1:]
+        return speculative_logits
 
 
 class MLPSpeculatorHead(nn.Module):
@@ -607,6 +607,7 @@ class MLPSpeculatorHead(nn.Module):
         if input.shape[0] > 128:
             return logits, None
 
+        input_ids = logits.argmax(dim=-1)
         speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
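
For orientation, below is a minimal, self-contained sketch (not part of the commit) of the per-head expansion performed by the loop in the first two hunks: each head takes a log-softmax over the vocabulary, keeps its top-k' tokens, and replicates the corresponding hidden state once per kept token, so the number of candidate suffixes grows as the product of top_k_tokens_per_head. The sizes, the random nn.Linear stand-in heads, and the (batch, candidates, hidden) layout are illustrative assumptions; the real model also applies the embedding, projection, and LayerNorm steps shown in the diff, and after this commit the log_probs/best_guesses pruning is disabled so the per-head distributions in all_probs[0] are returned directly.

import torch
import torch.nn.functional as F

b, d, vsize = 2, 8, 32             # batch size, hidden size, vocab size (illustrative)
n_predict = 3                      # number of speculator heads
top_k_tokens_per_head = [4, 3, 2]  # k' tokens kept per head

# Random linear layers stand in for the real speculator heads.
heads = [torch.nn.Linear(d, vsize) for _ in range(n_predict)]

state = torch.randn(b, 1, d)       # batch x candidates x hidden; one candidate to start
paths = []
for i in range(n_predict):
    probs = F.log_softmax(heads[i](state), dim=-1)                # b c v
    _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b c k'
    paths.append(preds)
    # Replicate each candidate's hidden state k' times so the next head scores
    # every extended path; the candidate dimension grows multiplicatively.
    state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b c k' d
    state = state.reshape(b, -1, d)                               # b (c*k') d

print(state.shape)  # torch.Size([2, 24, 8]) -> 4 * 3 * 2 candidate paths per sequence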
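
The last hunk only wires the head: it guards against large batches, derives input_ids greedily from the LM head logits, and calls the speculator. A hedged sketch of how a caller could consume the returned pair, assuming only the shapes implied by the commented assert above (logits is (num_tokens, vocab), speculative_logits is (num_tokens, n_predict, vocab)); random tensors stand in for real model output, and TGI's actual candidate verification lives elsewhere:

import torch

num_tokens, n_predict, vocab = 4, 3, 32
logits = torch.randn(num_tokens, vocab)
speculative_logits = torch.randn(num_tokens, n_predict, vocab)

next_token = logits[-1].argmax(dim=-1)                   # the regular next token
speculative_ids = speculative_logits[-1].argmax(dim=-1)  # n_predict greedy guesses to verify on the next pass

print(next_token.item(), speculative_ids.tolist())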