mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixed speculator.
This commit is contained in:
parent
9291d42865
commit
1fde6850bb
@ -547,7 +547,7 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
b = state.size(0)
|
b = state.size(0)
|
||||||
ind = input_ids[-b:].unsqueeze(0)
|
ind = input_ids[-b:].unsqueeze(0)
|
||||||
out = torch.empty(1, b, self.n_predict, device=state.device).int() # b k h
|
out = torch.empty(1, b, self.n_predict, device=state.device).int() # b k h
|
||||||
log_probs = torch.zeros(1, b, device=state.device) # b k
|
# log_probs = torch.zeros(1, b, device=state.device) # b k
|
||||||
all_probs = torch.empty(
|
all_probs = torch.empty(
|
||||||
1, b, self.n_predict, self.vsize, device=state.device
|
1, b, self.n_predict, self.vsize, device=state.device
|
||||||
) # b k h v
|
) # b k h v
|
||||||
@ -561,14 +561,14 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
|
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
|
||||||
state = self.proj[i](state) * self.state_weight + z
|
state = self.proj[i](state) * self.state_weight + z
|
||||||
state = self.activation(self.ln[i](state)) # b k d
|
state = self.activation(self.ln[i](state)) # b k d
|
||||||
_probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
|
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
|
||||||
probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||||
|
|
||||||
# Update candidate set with new predictions
|
# Update candidate set with new predictions
|
||||||
out[:, :, i : i + 1] = preds
|
out[:, :, i : i + 1] = preds
|
||||||
|
|
||||||
# Update distribution set with new logits
|
# Update distribution set with new logits
|
||||||
all_probs[:, :, i] = _probs.exp()
|
all_probs[:, :, i] = probs.exp()
|
||||||
|
|
||||||
# Update state, log_probs and ind for new predictions
|
# Update state, log_probs and ind for new predictions
|
||||||
state = state.unsqueeze(2).expand(
|
state = state.unsqueeze(2).expand(
|
||||||
@ -576,21 +576,21 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
) # b k k' d
|
) # b k k' d
|
||||||
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
||||||
ind = preds.view(-1, b) # b kk'
|
ind = preds.view(-1, b) # b kk'
|
||||||
log_probs = log_probs.unsqueeze(2).expand(
|
# log_probs = log_probs.unsqueeze(2).expand(
|
||||||
-1, b, top_k_tokens_per_head[i]
|
# -1, b, top_k_tokens_per_head[i]
|
||||||
) # b k k'
|
# ) # b k k'
|
||||||
log_probs = log_probs.add(probs).reshape(-1, b) # b kk'
|
# log_probs = log_probs.add(probs).reshape(-1, b) # b kk'
|
||||||
|
|
||||||
# print("done")
|
# print("done")
|
||||||
# Take only top n best guesses
|
# Take only top n best guesses
|
||||||
best_guesses = log_probs.topk(num_candidates, dim=1)[1] # b k
|
# best_guesses = log_probs.topk(num_candidates, dim=1)[1] # b k
|
||||||
# speculative_logits = all_probs.gather(
|
# speculative_logits = all_probs.gather(
|
||||||
# 1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
|
# 1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
|
||||||
# ).squeeze(0)
|
# ).squeeze(0)
|
||||||
speculative_logits = all_probs[0]
|
speculative_logits = all_probs[0]
|
||||||
# assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
|
# assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
|
||||||
# TODO Why is this shift existing, are speculative logits also including the natural next token ?
|
# TODO Why is this shift existing, are speculative logits also including the natural next token ?
|
||||||
return speculative_logits[:, 1:]
|
return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
class MLPSpeculatorHead(nn.Module):
|
class MLPSpeculatorHead(nn.Module):
|
||||||
@ -607,6 +607,7 @@ class MLPSpeculatorHead(nn.Module):
|
|||||||
if input.shape[0] > 128:
|
if input.shape[0] > 128:
|
||||||
return logits, None
|
return logits, None
|
||||||
|
|
||||||
|
input_ids = logits.argmax(dim=-1)
|
||||||
speculative_logits = self.mlp_speculator(input, input_ids)
|
speculative_logits = self.mlp_speculator(input, input_ids)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user