Mirror of https://github.com/huggingface/text-generation-inference.git
Fixed speculator.

commit 1fde6850bb
parent 9291d42865
@@ -547,7 +547,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         b = state.size(0)
         ind = input_ids[-b:].unsqueeze(0)
         out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
-        log_probs = torch.zeros(1, b, device=state.device)  # b k
+        # log_probs = torch.zeros(1, b, device=state.device)  # b k
         all_probs = torch.empty(
             1, b, self.n_predict, self.vsize, device=state.device
         )  # b k h v
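For context, a minimal shape sketch (illustrative only, with made-up sizes; not the TGI source) of the buffers this hunk allocates: `out` holds one candidate token id per speculative head, and `all_probs` one full vocabulary distribution per head.

    import torch

    b, n_predict, vsize = 4, 3, 32000      # made-up batch size, head count, vocab size
    state = torch.randn(b, 512)            # stand-in for the model's hidden states

    out = torch.empty(1, b, n_predict, device=state.device).int()         # candidate ids, "b k h"
    all_probs = torch.empty(1, b, n_predict, vsize, device=state.device)  # per-head distributions, "b k h v"
    print(out.shape, all_probs.shape)  # torch.Size([1, 4, 3]) torch.Size([1, 4, 3, 32000])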
@@ -561,14 +561,14 @@ class MLPSpeculatorModel(torch.nn.Module):
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
+            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

             # Update candidate set with new predictions
             out[:, :, i : i + 1] = preds

             # Update distribution set with new logits
-            all_probs[:, :, i] = _probs.exp()
+            all_probs[:, :, i] = probs.exp()

             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
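The change in this hunk is a rename: `probs` now names the full log-distribution over the vocabulary and `_probs` the top-k values, so `all_probs[:, :, i] = probs.exp()` visibly stores the whole distribution per head. A small standalone check of the shapes involved (sizes are made up):

    import torch
    import torch.nn.functional as F

    b, vsize, k = 4, 32000, 1                # made-up sizes; k mirrors top_k_tokens_per_head[i]
    head_out = torch.randn(1, b, vsize)      # stand-in for self.head[i](state)

    probs = F.log_softmax(head_out, dim=-1)  # full log-probs, "b k v"
    _probs, preds = probs.topk(k, dim=-1)    # top-k log-probs and token ids, "b k k'"
    print(probs.shape, _probs.shape, preds.shape)
    # torch.Size([1, 4, 32000]) torch.Size([1, 4, 1]) torch.Size([1, 4, 1])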
@@ -576,21 +576,21 @@ class MLPSpeculatorModel(torch.nn.Module):
             )  # b k k' d
             state = state.reshape(-1, b, state.size(3))  # b kk' d
             ind = preds.view(-1, b)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(
-                -1, b, top_k_tokens_per_head[i]
-            )  # b k k'
-            log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
+            # log_probs = log_probs.unsqueeze(2).expand(
+            #     -1, b, top_k_tokens_per_head[i]
+            # )  # b k k'
+            # log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'

         # print("done")
         # Take only top n best guesses
-        best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
+        # best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
         # speculative_logits = all_probs.gather(
         #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
         # ).squeeze(0)
         speculative_logits = all_probs[0]
         # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
         # TODO Why is this shift existing, are speculative logits also including the natural next token ?
-        return speculative_logits[:, 1:]
+        return speculative_logits


 class MLPSpeculatorHead(nn.Module):
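The two return statements differ by one head: with a single candidate track in the leading dimension, `all_probs[0]` has shape (b, n_predict, vsize), and the old `[:, 1:]` slice would drop the first head's distribution entirely. A minimal shape check (illustrative sizes, not the TGI source):

    import torch

    b, n_predict, vsize = 4, 3, 32000
    all_probs = torch.rand(1, b, n_predict, vsize)  # "b k h v" with a single candidate track

    speculative_logits = all_probs[0]
    assert speculative_logits.shape == (b, n_predict, vsize)

    shifted = speculative_logits[:, 1:]             # the old return value
    assert shifted.shape == (b, n_predict - 1, vsize)  # one head's logits lost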
@@ -607,6 +607,7 @@ class MLPSpeculatorHead(nn.Module):
+        if input.shape[0] > 128:
+            return logits, None

         input_ids = logits.argmax(dim=-1)
         speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
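A runnable sketch (assumed module layout, hypothetical names; not the TGI source) of the guard this hunk introduces: past a batch size of 128 the head skips speculation and returns only the base logits, presumably because verifying speculative tokens stops paying off at large batch sizes.

    import torch
    import torch.nn as nn

    class SketchSpeculatorHead(nn.Module):  # hypothetical stand-in for MLPSpeculatorHead
        def __init__(self, hidden, vocab, speculator):
            super().__init__()
            self.head = nn.Linear(hidden, vocab)
            self.mlp_speculator = speculator  # any callable (state, input_ids) -> tensor

        def forward(self, input):
            logits = self.head(input)
            if input.shape[0] > 128:  # large batch: skip speculation entirely
                return logits, None
            input_ids = logits.argmax(dim=-1)
            speculative_logits = self.mlp_speculator(input, input_ids)
            return logits, speculative_logits

    # Usage: below the threshold the speculator runs; above it the second output is None.
    spec = lambda state, ids: torch.zeros(ids.shape[0], 3, 100)  # placeholder speculator
    head = SketchSpeculatorHead(16, 100, spec)
    _, s1 = head(torch.randn(4, 16))    # s1 has shape (4, 3, 100)
    _, s2 = head(torch.randn(256, 16))  # s2 is None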