Fixing the scale_weight when users decide to not use the speculation as much as defined in the config.
Nicolas Patry 2024-08-29 12:33:45 +02:00
parent 62a8343153
commit 09a1de5cd1


@@ -167,7 +167,7 @@ class MLPSpeculatorModel(torch.nn.Module):
             )
         )
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
-        self.state_weight = 0.5 ** (0.5 / self.n_predict)
+        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
         self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
         # TODO
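
For context, the change guards the state_weight computation against n_predict == 0 (speculation effectively disabled), where the original expression 0.5 ** (0.5 / self.n_predict) would raise a ZeroDivisionError. Below is a minimal, standalone sketch of the guarded computation; it is not the repository's code, and speculator_weights is a hypothetical helper used only for illustration.

import math

def speculator_weights(n_predict: int):
    # Guarded form from the commit: fall back to 1 when speculation is
    # disabled (n_predict == 0) instead of dividing by zero.
    state_weight = 0.5 ** (0.5 / n_predict) if n_predict > 0 else 1
    # Embedding weight chosen so the state and embedding contributions
    # remain normalized: state_weight^2 + emb_weight^2 == 1.
    emb_weight = math.sqrt(1 - state_weight**2)
    return state_weight, emb_weight

print(speculator_weights(3))  # e.g. n_predict=3 -> (~0.891, ~0.454)
print(speculator_weights(0))  # speculation disabled -> (1, 0.0)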