diff --git a/server/text_generation_server/layers/mlp.py b/server/text_generation_server/layers/mlp.py index 35f0bf56..3884808b 100644 --- a/server/text_generation_server/layers/mlp.py +++ b/server/text_generation_server/layers/mlp.py @@ -167,7 +167,7 @@ class MLPSpeculatorModel(torch.nn.Module): ) # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation - self.state_weight = 0.5 ** (0.5 / self.n_predict) + self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1 self.emb_weight = math.sqrt(1 - self.state_weight**2) self.activation = nn.GELU() # TODO