diff --git a/server/text_generation_server/layers/mlp.py b/server/text_generation_server/layers/mlp.py
index 3884808b..d33b41f3 100644
--- a/server/text_generation_server/layers/mlp.py
+++ b/server/text_generation_server/layers/mlp.py
@@ -45,6 +45,21 @@ class MLPSpeculatorLayerNorm(nn.Module):
         return x
 
 
+INV_SQRT2 = 2**-0.5
+
+
+def simple_norm(x: torch.Tensor, eps=1e-06):
+    """RMS-normalize `x` over its last dimension and scale by 1/sqrt(2).
+
+    No learned weight is applied; `eps` is added to the mean square before
+    taking the reciprocal square root.
+    """
+    xf = x
+    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+    x = xf.type_as(x)
+    return x * INV_SQRT2
+
+
 class MLPSpeculatorModelTied(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -74,12 +89,16 @@ class MLPSpeculatorModelTied(torch.nn.Module):
 
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
-        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        # Fold the embedding scale into the weights once at load time so that
+        # forward() does not have to rescale on every step.
+        self.emb.weight *= self.emb_weight
 
     def forward(
         self,
@@ -102,7 +121,7 @@ class MLPSpeculatorModelTied(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb(ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z is b k d; emb_weight is already folded into self.emb.weight
             if i == 0:
                 state = self.proj0(state) * self.state_weight + z
             else:
@@ -168,12 +187,18 @@ class MLPSpeculatorModel(torch.nn.Module):
 
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
-        # TODO
         self.vsize = config.vocab_size
         self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
+        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+            self.inner_dim / 2
+        )
+        # Fold the embedding scale into the weights once at load time so that
+        # forward() does not have to rescale on every step. self.emb is
+        # indexed per head in forward(), so scale each embedding separately.
+        for emb in self.emb:
+            emb.weight *= self.emb_weight
 
     def forward(
         self,
@@ -196,7 +221,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb[i](ind)
-            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            # z is b k d; emb_weight is already folded into each emb weight
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
             probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
@@ -219,10 +244,11 @@ class MLPSpeculatorModel(torch.nn.Module):
 
 
 class MLPSpeculatorHead(nn.Module):
-    def __init__(self, lm_head, mlp_speculator):
+    def __init__(self, lm_head, mlp_speculator, scale_input: bool = False):
         super().__init__()
         self.lm_head = lm_head
         self.mlp_speculator = mlp_speculator
+        self.scale_input = scale_input
 
     def forward(
         self, input: torch.Tensor
@@ -233,6 +259,8 @@ class MLPSpeculatorHead(nn.Module):
             return logits, None
 
         input_ids = logits.argmax(dim=-1)
+        if self.scale_input:
+            input = simple_norm(input)
         speculative_logits = self.mlp_speculator(input, input_ids)
 
         return logits, speculative_logits
@@ -259,5 +287,7 @@ class MLPSpeculatorHead(nn.Module):
             mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
         else:
             mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+        scale_input = config.speculator_config.get("scale_input", False)
         lm_head = TensorParallelHead.load(config, prefix, weights)
-        return MLPSpeculatorHead(lm_head, mlp_speculator)
+        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)